diff --git a/Makefile b/Makefile index 030b0042..f567bc54 100644 --- a/Makefile +++ b/Makefile @@ -19,21 +19,17 @@ all: mac linux simple-responder clean: rm -rf $(BUILD_DIR) -proxy/ui_dist/placeholder.txt: - mkdir -p proxy/ui_dist - touch $@ - # use cached test results while developing -test-dev: proxy/ui_dist/placeholder.txt - go test -short ./proxy/... ./internal/... - staticcheck ./proxy/... ./internal/... || true +test-dev: + go test -short ./... + staticcheck ./... || true -test: proxy/ui_dist/placeholder.txt - go test -short -count=1 ./proxy/... ./internal/... +test: + go test -short -count=1 ./internal/... # for CI - full test (takes longer) -test-all: proxy/ui_dist/placeholder.txt - go test -race -count=1 ./proxy/... ./internal/... +test-all: + go test -race -count=1 ./internal/... ui/node_modules: cd ui-svelte && npm install @@ -64,7 +60,7 @@ windows: ui @echo "Building Windows binary..." GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe -# for testing proxy.Process +# for testing with real external processes simple-responder: @echo "Building simple responder" GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go diff --git a/internal/perf/monitor.go b/internal/perf/monitor.go index 1e2c1c77..fd3f0a88 100644 --- a/internal/perf/monitor.go +++ b/internal/perf/monitor.go @@ -12,7 +12,7 @@ import ( ) var ( - ErrNotImplemented = errors.New("Not Implemented") + ErrNotImplemented = errors.New("not implemented") ErrNoGpuTool = errors.New("no GPU monitoring tool available") ) diff --git a/internal/router/peer.go b/internal/router/peer.go index a017cca1..fb17068c 100644 --- a/internal/router/peer.go +++ b/internal/router/peer.go @@ -62,13 +62,12 @@ func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) { IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second, } - reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL) - reverseProxy.Transport = peerTransport - - originalDirector := reverseProxy.Director - reverseProxy.Director = func(req *http.Request) { - originalDirector(req) - req.Host = req.URL.Host + reverseProxy := &httputil.ReverseProxy{ + Transport: peerTransport, + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(peer.ProxyURL) + r.Out.Host = r.Out.URL.Host + }, } reverseProxy.ModifyResponse = func(resp *http.Response) error { diff --git a/internal/server/extras_test.go b/internal/server/extras_test.go index f881ce4c..4d8a787c 100644 --- a/internal/server/extras_test.go +++ b/internal/server/extras_test.go @@ -1,10 +1,12 @@ package server import ( + "bufio" "bytes" "compress/flate" "compress/gzip" "io" + "net" "net/http" "net/http/httptest" "testing" @@ -75,6 +77,55 @@ func TestServer_BodyCopier_Flush(t *testing.T) { } } +// hijackRecorder is an httptest.ResponseRecorder that also implements +// http.Hijacker, returning a pipe so Hijack forwarding can be exercised. +type hijackRecorder struct { + *httptest.ResponseRecorder + conn net.Conn +} + +func (h *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil +} + +func TestServer_BodyCopier_Hijack(t *testing.T) { + t.Run("forwards to underlying hijacker", func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + bc := newBodyCopier(&hijackRecorder{httptest.NewRecorder(), server}) + conn, _, err := bc.Hijack() + if err != nil { + t.Fatalf("Hijack: %v", err) + } + if conn != server { + t.Errorf("Hijack returned unexpected conn") + } + }) + + t.Run("errors when underlying writer is not a hijacker", func(t *testing.T) { + bc := newBodyCopier(httptest.NewRecorder()) + if _, _, err := bc.Hijack(); err == nil { + t.Error("expected error hijacking a non-Hijacker ResponseWriter") + } + }) +} + +func TestServer_BodyCopier_SkipsBufferingOnUpgrade(t *testing.T) { + rec := httptest.NewRecorder() + bc := newBodyCopier(rec) + bc.WriteHeader(http.StatusSwitchingProtocols) + bc.Write([]byte("websocket frame bytes")) + + if bc.body.Len() != 0 { + t.Errorf("upgrade body buffered = %q, want empty", bc.body.Bytes()) + } + if got := rec.Body.String(); got != "websocket frame bytes" { + t.Errorf("client body = %q, want %q", got, "websocket frame bytes") + } +} + func TestServer_HeaderMapAndRedact(t *testing.T) { h := http.Header{ "Content-Type": {"application/json"}, diff --git a/internal/server/log.go b/internal/server/log.go index 2a7d9d5c..a41d5952 100644 --- a/internal/server/log.go +++ b/internal/server/log.go @@ -1,6 +1,7 @@ package server import ( + "bufio" "context" "fmt" "io" @@ -150,7 +151,8 @@ var requestLogPathSkips = []string{"/wol-health", "/api/performance", "/metrics" // statusRecorder wraps an http.ResponseWriter to capture the response status // code and the number of body bytes written, so the access log can report -// them. Flush is forwarded so streaming handlers (SSE) still work. +// them. Flush is forwarded so streaming handlers (SSE) still work, and Hijack +// is forwarded so httputil.ReverseProxy can upgrade websocket connections. type statusRecorder struct { http.ResponseWriter status int @@ -174,6 +176,15 @@ func (sr *statusRecorder) Flush() { } } +// Hijack forwards to the underlying ResponseWriter so httputil.ReverseProxy can +// take over the connection for websocket upgrades. +func (sr *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := sr.ResponseWriter.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking") +} + // clientIP resolves the originating client address, preferring proxy headers // over the raw connection address. func clientIP(r *http.Request) string { diff --git a/internal/server/log_test.go b/internal/server/log_test.go index 1af8cc4f..2adf823d 100644 --- a/internal/server/log_test.go +++ b/internal/server/log_test.go @@ -1,11 +1,16 @@ package server import ( + "bufio" "io" + "net" "net/http" "net/http/httptest" + "net/http/httputil" + "net/url" "strings" "testing" + "time" "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" @@ -135,3 +140,103 @@ func TestServer_RequestLogMiddleware(t *testing.T) { }) } } + +// TestServer_RequestLogMiddleware_WebSocketUpgrade verifies that the access-log +// middleware (which wraps responses in statusRecorder) does not break websocket +// upgrades proxied through httputil.ReverseProxy. ReverseProxy requires the +// ResponseWriter to implement http.Hijacker to take over the connection; if +// statusRecorder does not forward Hijack, the upgrade is refused with 502. +func TestServer_RequestLogMiddleware_WebSocketUpgrade(t *testing.T) { + // Upstream: complete the upgrade handshake then echo bytes back. This + // stands in for an upstream that speaks websocket; ReverseProxy only cares + // about the 101 response and then copies raw bytes both ways. + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Errorf("upstream ResponseWriter is not an http.Hijacker") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Errorf("upstream hijack: %v", err) + return + } + defer conn.Close() + brw.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + brw.Flush() + // Echo whatever the client sends. + buf := make([]byte, 64) + n, err := brw.Read(buf) + if err != nil { + return + } + brw.Write(buf[:n]) + brw.Flush() + })) + defer upstream.Close() + + upstreamURL, err := url.Parse(upstream.URL) + if err != nil { + t.Fatalf("parse upstream URL: %v", err) + } + + // Front server: ReverseProxy wrapped in the access-log middleware, which is + // the production statusRecorder-wrapped path. + proxy := httputil.NewSingleHostReverseProxy(upstreamURL) + mw := CreateRequestLogMiddleware(logmon.NewWriter(io.Discard)) + front := httptest.NewServer(mw(proxy)) + defer front.Close() + + frontURL, err := url.Parse(front.URL) + if err != nil { + t.Fatalf("parse front URL: %v", err) + } + + conn, err := net.DialTimeout("tcp", frontURL.Host, 5*time.Second) + if err != nil { + t.Fatalf("dial front: %v", err) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(5 * time.Second)) + + req := "GET / HTTP/1.1\r\n" + + "Host: " + frontURL.Host + "\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: websocket\r\n" + + "\r\n" + if _, err := conn.Write([]byte(req)); err != nil { + t.Fatalf("write upgrade request: %v", err) + } + + br := bufio.NewReader(conn) + statusLine, err := br.ReadString('\n') + if err != nil { + t.Fatalf("read status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + t.Fatalf("websocket upgrade failed: status line = %q, want 101 Switching Protocols", strings.TrimSpace(statusLine)) + } + + // Drain the rest of the response headers. + for { + line, err := br.ReadString('\n') + if err != nil { + t.Fatalf("read headers: %v", err) + } + if strings.TrimSpace(line) == "" { + break + } + } + + // Verify bytes flow through the hijacked connection. + if _, err := conn.Write([]byte("ping")); err != nil { + t.Fatalf("write payload: %v", err) + } + echo := make([]byte, 4) + if _, err := io.ReadFull(br, echo); err != nil { + t.Fatalf("read echo: %v", err) + } + if string(echo) != "ping" { + t.Errorf("echo = %q, want %q", echo, "ping") + } +} diff --git a/internal/server/metrics.go b/internal/server/metrics.go index 71ef1f84..733f9c1f 100644 --- a/internal/server/metrics.go +++ b/internal/server/metrics.go @@ -1,12 +1,14 @@ package server import ( + "bufio" "bytes" "compress/flate" "compress/gzip" "encoding/json" "fmt" "io" + "net" "net/http" "strings" "sync" @@ -427,6 +429,12 @@ func (w *responseBodyCopier) Write(b []byte) (int, error) { if !w.wroteHeader { w.WriteHeader(http.StatusOK) } + // On a protocol upgrade (e.g. websocket) the body is raw framed data, not a + // metrics-parseable response, so write straight to the client without + // buffering a copy we can't use. + if w.status == http.StatusSwitchingProtocols { + return w.ResponseWriter.Write(b) + } return w.tee.Write(b) } @@ -446,5 +454,14 @@ func (w *responseBodyCopier) Flush() { } } +// Hijack forwards to the underlying writer so httputil.ReverseProxy can take +// over the connection for websocket upgrades. +func (w *responseBodyCopier) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := w.ResponseWriter.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking") +} + func (w *responseBodyCopier) Status() int { return w.status } func (w *responseBodyCopier) StartTime() time.Time { return w.start }