Makefile,internal: fix websocket regression and other small things (#830)

- fix websocket regression and add test to prevent in the future
- fix staticheck errors
- remove proxy package remnants from Makefile 

fix #829
This commit is contained in:
Benson Wong
2026-06-09 21:37:53 -07:00
committed by GitHub
parent 44e1501e81
commit 0cfe5a6639
7 changed files with 200 additions and 21 deletions
+8 -12
View File
@@ -19,21 +19,17 @@ all: mac linux simple-responder
clean: clean:
rm -rf $(BUILD_DIR) rm -rf $(BUILD_DIR)
proxy/ui_dist/placeholder.txt:
mkdir -p proxy/ui_dist
touch $@
# use cached test results while developing # use cached test results while developing
test-dev: proxy/ui_dist/placeholder.txt test-dev:
go test -short ./proxy/... ./internal/... go test -short ./...
staticcheck ./proxy/... ./internal/... || true staticcheck ./... || true
test: proxy/ui_dist/placeholder.txt test:
go test -short -count=1 ./proxy/... ./internal/... go test -short -count=1 ./internal/...
# for CI - full test (takes longer) # for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt test-all:
go test -race -count=1 ./proxy/... ./internal/... go test -race -count=1 ./internal/...
ui/node_modules: ui/node_modules:
cd ui-svelte && npm install cd ui-svelte && npm install
@@ -64,7 +60,7 @@ windows: ui
@echo "Building Windows binary..." @echo "Building Windows binary..."
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
# for testing proxy.Process # for testing with real external processes
simple-responder: simple-responder:
@echo "Building simple responder" @echo "Building simple responder"
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
+1 -1
View File
@@ -12,7 +12,7 @@ import (
) )
var ( var (
ErrNotImplemented = errors.New("Not Implemented") ErrNotImplemented = errors.New("not implemented")
ErrNoGpuTool = errors.New("no GPU monitoring tool available") ErrNoGpuTool = errors.New("no GPU monitoring tool available")
) )
+6 -7
View File
@@ -62,13 +62,12 @@ func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second, IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
} }
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL) reverseProxy := &httputil.ReverseProxy{
reverseProxy.Transport = peerTransport Transport: peerTransport,
Rewrite: func(r *httputil.ProxyRequest) {
originalDirector := reverseProxy.Director r.SetURL(peer.ProxyURL)
reverseProxy.Director = func(req *http.Request) { r.Out.Host = r.Out.URL.Host
originalDirector(req) },
req.Host = req.URL.Host
} }
reverseProxy.ModifyResponse = func(resp *http.Response) error { reverseProxy.ModifyResponse = func(resp *http.Response) error {
+51
View File
@@ -1,10 +1,12 @@
package server package server
import ( import (
"bufio"
"bytes" "bytes"
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
"io" "io"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@@ -75,6 +77,55 @@ func TestServer_BodyCopier_Flush(t *testing.T) {
} }
} }
// hijackRecorder is an httptest.ResponseRecorder that also implements
// http.Hijacker, returning a pipe so Hijack forwarding can be exercised.
type hijackRecorder struct {
*httptest.ResponseRecorder
conn net.Conn
}
func (h *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil
}
func TestServer_BodyCopier_Hijack(t *testing.T) {
t.Run("forwards to underlying hijacker", func(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
bc := newBodyCopier(&hijackRecorder{httptest.NewRecorder(), server})
conn, _, err := bc.Hijack()
if err != nil {
t.Fatalf("Hijack: %v", err)
}
if conn != server {
t.Errorf("Hijack returned unexpected conn")
}
})
t.Run("errors when underlying writer is not a hijacker", func(t *testing.T) {
bc := newBodyCopier(httptest.NewRecorder())
if _, _, err := bc.Hijack(); err == nil {
t.Error("expected error hijacking a non-Hijacker ResponseWriter")
}
})
}
func TestServer_BodyCopier_SkipsBufferingOnUpgrade(t *testing.T) {
rec := httptest.NewRecorder()
bc := newBodyCopier(rec)
bc.WriteHeader(http.StatusSwitchingProtocols)
bc.Write([]byte("websocket frame bytes"))
if bc.body.Len() != 0 {
t.Errorf("upgrade body buffered = %q, want empty", bc.body.Bytes())
}
if got := rec.Body.String(); got != "websocket frame bytes" {
t.Errorf("client body = %q, want %q", got, "websocket frame bytes")
}
}
func TestServer_HeaderMapAndRedact(t *testing.T) { func TestServer_HeaderMapAndRedact(t *testing.T) {
h := http.Header{ h := http.Header{
"Content-Type": {"application/json"}, "Content-Type": {"application/json"},
+12 -1
View File
@@ -1,6 +1,7 @@
package server package server
import ( import (
"bufio"
"context" "context"
"fmt" "fmt"
"io" "io"
@@ -150,7 +151,8 @@ var requestLogPathSkips = []string{"/wol-health", "/api/performance", "/metrics"
// statusRecorder wraps an http.ResponseWriter to capture the response status // statusRecorder wraps an http.ResponseWriter to capture the response status
// code and the number of body bytes written, so the access log can report // code and the number of body bytes written, so the access log can report
// them. Flush is forwarded so streaming handlers (SSE) still work. // them. Flush is forwarded so streaming handlers (SSE) still work, and Hijack
// is forwarded so httputil.ReverseProxy can upgrade websocket connections.
type statusRecorder struct { type statusRecorder struct {
http.ResponseWriter http.ResponseWriter
status int status int
@@ -174,6 +176,15 @@ func (sr *statusRecorder) Flush() {
} }
} }
// Hijack forwards to the underlying ResponseWriter so httputil.ReverseProxy can
// take over the connection for websocket upgrades.
func (sr *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := sr.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking")
}
// clientIP resolves the originating client address, preferring proxy headers // clientIP resolves the originating client address, preferring proxy headers
// over the raw connection address. // over the raw connection address.
func clientIP(r *http.Request) string { func clientIP(r *http.Request) string {
+105
View File
@@ -1,11 +1,16 @@
package server package server
import ( import (
"bufio"
"io" "io"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil"
"net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
@@ -135,3 +140,103 @@ func TestServer_RequestLogMiddleware(t *testing.T) {
}) })
} }
} }
// TestServer_RequestLogMiddleware_WebSocketUpgrade verifies that the access-log
// middleware (which wraps responses in statusRecorder) does not break websocket
// upgrades proxied through httputil.ReverseProxy. ReverseProxy requires the
// ResponseWriter to implement http.Hijacker to take over the connection; if
// statusRecorder does not forward Hijack, the upgrade is refused with 502.
func TestServer_RequestLogMiddleware_WebSocketUpgrade(t *testing.T) {
// Upstream: complete the upgrade handshake then echo bytes back. This
// stands in for an upstream that speaks websocket; ReverseProxy only cares
// about the 101 response and then copies raw bytes both ways.
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hj, ok := w.(http.Hijacker)
if !ok {
t.Errorf("upstream ResponseWriter is not an http.Hijacker")
return
}
conn, brw, err := hj.Hijack()
if err != nil {
t.Errorf("upstream hijack: %v", err)
return
}
defer conn.Close()
brw.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
brw.Flush()
// Echo whatever the client sends.
buf := make([]byte, 64)
n, err := brw.Read(buf)
if err != nil {
return
}
brw.Write(buf[:n])
brw.Flush()
}))
defer upstream.Close()
upstreamURL, err := url.Parse(upstream.URL)
if err != nil {
t.Fatalf("parse upstream URL: %v", err)
}
// Front server: ReverseProxy wrapped in the access-log middleware, which is
// the production statusRecorder-wrapped path.
proxy := httputil.NewSingleHostReverseProxy(upstreamURL)
mw := CreateRequestLogMiddleware(logmon.NewWriter(io.Discard))
front := httptest.NewServer(mw(proxy))
defer front.Close()
frontURL, err := url.Parse(front.URL)
if err != nil {
t.Fatalf("parse front URL: %v", err)
}
conn, err := net.DialTimeout("tcp", frontURL.Host, 5*time.Second)
if err != nil {
t.Fatalf("dial front: %v", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
req := "GET / HTTP/1.1\r\n" +
"Host: " + frontURL.Host + "\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: websocket\r\n" +
"\r\n"
if _, err := conn.Write([]byte(req)); err != nil {
t.Fatalf("write upgrade request: %v", err)
}
br := bufio.NewReader(conn)
statusLine, err := br.ReadString('\n')
if err != nil {
t.Fatalf("read status line: %v", err)
}
if !strings.Contains(statusLine, "101") {
t.Fatalf("websocket upgrade failed: status line = %q, want 101 Switching Protocols", strings.TrimSpace(statusLine))
}
// Drain the rest of the response headers.
for {
line, err := br.ReadString('\n')
if err != nil {
t.Fatalf("read headers: %v", err)
}
if strings.TrimSpace(line) == "" {
break
}
}
// Verify bytes flow through the hijacked connection.
if _, err := conn.Write([]byte("ping")); err != nil {
t.Fatalf("write payload: %v", err)
}
echo := make([]byte, 4)
if _, err := io.ReadFull(br, echo); err != nil {
t.Fatalf("read echo: %v", err)
}
if string(echo) != "ping" {
t.Errorf("echo = %q, want %q", echo, "ping")
}
}
+17
View File
@@ -1,12 +1,14 @@
package server package server
import ( import (
"bufio"
"bytes" "bytes"
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@@ -427,6 +429,12 @@ func (w *responseBodyCopier) Write(b []byte) (int, error) {
if !w.wroteHeader { if !w.wroteHeader {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
// On a protocol upgrade (e.g. websocket) the body is raw framed data, not a
// metrics-parseable response, so write straight to the client without
// buffering a copy we can't use.
if w.status == http.StatusSwitchingProtocols {
return w.ResponseWriter.Write(b)
}
return w.tee.Write(b) return w.tee.Write(b)
} }
@@ -446,5 +454,14 @@ func (w *responseBodyCopier) Flush() {
} }
} }
// Hijack forwards to the underlying writer so httputil.ReverseProxy can take
// over the connection for websocket upgrades.
func (w *responseBodyCopier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking")
}
func (w *responseBodyCopier) Status() int { return w.status } func (w *responseBodyCopier) Status() int { return w.status }
func (w *responseBodyCopier) StartTime() time.Time { return w.start } func (w *responseBodyCopier) StartTime() time.Time { return w.start }