commit ae18c0ea3bc5c898c31381f7adcd506460f967a0 Author: Steve Dudenhoeffer Date: Fri Oct 6 01:58:58 2023 -0400 initial commit diff --git a/cmd/cmd.go b/cmd/cmd.go new file mode 100644 index 0000000..e885171 --- /dev/null +++ b/cmd/cmd.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + "github.com/sduden/simpleproxy" + "os" +) + +// simpleproxy is a simple proxy server. +// the first argument passed to the program is the address of the server to proxy to. +// the second argument is the address to listen on. +// if not specified, listens on 0.0.0.0:8080 +func main() { + if len(os.Args) < 2 { + fmt.Printf("Usage: %s [listen]\n", os.Args[0]) + os.Exit(1) + } + + server := os.Args[1] + + listen := ":8080" + + if len(os.Args) > 2 { + listen = os.Args[2] + } + + fmt.Printf("Listening on %s, proxying to %s\n", listen, server) + + proxy := simpleproxy.NewProxy(listen, server) + proxy.ListenAndServe() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a970944 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/sduden/simpleproxy + +go 1.21 diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..d5bafee --- /dev/null +++ b/proxy.go @@ -0,0 +1,82 @@ +package simpleproxy + +import ( + "io" + "net/http" + "net/url" +) + +type Proxy struct { + ListenAddr string + ServerAddr string + + u *url.URL +} + +// NewProxy creates a new proxy server. +func NewProxy(listenAddr, serverAddr string) *Proxy { + u, err := url.Parse(serverAddr) + if err != nil { + panic(err) + } + + return &Proxy{ + ListenAddr: listenAddr, + ServerAddr: serverAddr, + u: u, + } +} + +// ListenAndServe starts the proxy server. +func (p *Proxy) ListenAndServe() error { + srv := &http.Server{ + Addr: p.ListenAddr, + Handler: p, + } + + return srv.ListenAndServe() + +} + +// ServeHTTP is the main handler for the proxy server. +// all requests here should be proxied to the server. +func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // all requests should be proxied to the server. + req, err := http.NewRequest(r.Method, p.ServerAddr+r.URL.String(), r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // copy headers + for k, v := range r.Header { + req.Header[k] = v + } + + // make the request + resp, err := http.DefaultClient.Do(req) + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + mirrorResponse(w, resp) + +} + +func mirrorResponse(w http.ResponseWriter, resp *http.Response) { + // Copy all headers from the original response + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Copy the status code from the original response + w.WriteHeader(resp.StatusCode) + + // Copy the response body + io.Copy(w, resp.Body) + resp.Body.Close() +}