From a15e47922c3db3e24df75b2648acad5acafeda4b Mon Sep 17 00:00:00 2001 From: Benson Wong <83972+mostlygeek@users.noreply.github.com> Date: Wed, 17 Jun 2026 17:38:52 -0700 Subject: [PATCH] proxy: meter /upstream requests via metrics middleware (#858) Wrap /upstream/{upstreamPath...} in the metrics middleware so activity log entries are recorded for model-dispatched endpoints accessed through the upstream passthrough. - Move findModelInPath to shared.FindModelInPath and reuse it in handleUpstream, the log monitor lookup, and FetchContext. - Extend FetchContext to resolve the model from /upstream/<model>/... paths without consuming the request body. - Add isMetricsRecordPath to limit recording to the model-dispatched endpoints that produce token usage/timings. - Add tests for upstream metrics recording and FetchContext upstream path resolution. Fixes #855 --- internal/server/api.go | 28 +-------- internal/server/api_test.go | 84 ++++++++++++++++++++++++++- internal/server/log.go | 2 +- internal/server/metrics_middleware.go | 24 +++++++- internal/server/metrics_test.go | 40 +++++++++++++ internal/server/server.go | 27 ++++++++- internal/shared/http.go | 69 ++++++++++++++++++++-- internal/shared/http_test.go | 67 +++++++++++++++++++++ 8 files changed, 303 insertions(+), 38 deletions(-) diff --git a/internal/server/api.go b/internal/server/api.go index b9ed8ba3..96cafed8 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -314,7 +314,7 @@ func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) { func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { upstreamPath := r.PathValue("upstreamPath") - searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath) + searchName, modelID, remainingPath, found := shared.FindModelInPath(s.cfg, "/"+upstreamPath) if !found { shared.SendResponse(w, r, http.StatusNotFound, "model not found") return @@ -349,29 +349,3 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r
*http.Request) { shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID) } } - -// findModelInPath walks a slash-separated path, building up segments until one -// matches a configured model. This resolves model names that contain slashes -// (e.g. "author/model"). Returns the matched name, its real model ID, the -// remaining path, and whether a match was found. -func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) { - parts := strings.Split(strings.TrimSpace(path), "/") - name := "" - - for i, part := range parts { - if part == "" { - continue - } - if name == "" { - name = part - } else { - name = name + "/" + part - } - - if modelID, ok := cfg.RealModelName(name); ok { - return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true - } - } - - return "", "", "", false -} diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 7b92d6b2..5924bc06 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -2,11 +2,15 @@ package server import ( "encoding/json" + "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/shared" ) func TestServer_HandleListModels(t *testing.T) { @@ -78,6 +82,7 @@ func TestServer_HandleListModels_Aliases(t *testing.T) { func TestServer_FindModelInPath(t *testing.T) { cfg := config.Config{Models: map[string]config.ModelConfig{ + "author": {}, "author/model": {}, "simple": {}, }} @@ -91,13 +96,14 @@ func TestServer_FindModelInPath(t *testing.T) { {"/simple/v1/chat", "simple", "/v1/chat", true}, {"/author/model/v1/chat", "author/model", "/v1/chat", true}, {"/author/model", "author/model", "/", true}, + {"/author/v1/chat", "author", "/v1/chat", true}, {"/missing/v1", "", "", false}, {"/", "", "", false}, } for _, c := range cases { - name, _, rem, found := 
findModelInPath(cfg, c.path) + name, _, rem, found := shared.FindModelInPath(cfg, c.path) if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) { - t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)", + t.Errorf("FindModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)", c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound) } } @@ -133,6 +139,80 @@ func TestServer_HandleUpstream(t *testing.T) { }) } +func upstreamMetricsServer(response string) *Server { + cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}} + proxylog := logmon.NewWriter(io.Discard) + s := &Server{ + cfg: cfg, + muxlog: logmon.NewWriter(io.Discard), + proxylog: proxylog, + upstreamlog: logmon.NewWriter(io.Discard), + inflight: &inflightCounter{}, + metrics: newMetricsMonitor(proxylog, 10, 0), + local: newStubRouter([]string{"m1"}, response), + peer: newStubRouter(nil, ""), + } + s.routes() + return s +} + +func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) { + resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}` + s := upstreamMetricsServer(resp) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + s.ServeHTTP(w, req) + + if w.Code != http.StatusOK || w.Body.String() != resp { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + entries := s.metrics.getMetrics() + if len(entries) != 1 { + t.Fatalf("want 1 metrics entry, got %d", len(entries)) + } + if entries[0].Model != "m1" { + t.Errorf("model = %q, want m1", entries[0].Model) + } + if entries[0].ReqPath != "/v1/chat/completions" { + t.Errorf("req_path = %q, want /v1/chat/completions", entries[0].ReqPath) + } + if entries[0].Tokens.InputTokens != 3 || entries[0].Tokens.OutputTokens != 5 { + t.Errorf("tokens = %+v, want input=3 output=5", entries[0].Tokens) + } +} + +func 
TestServer_HandleUpstream_MetricsSkipsUnsupportedPath(t *testing.T) { + s := upstreamMetricsServer("ok") + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/upstream/m1/probe", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + s.ServeHTTP(w, req) + + if w.Code != http.StatusOK || w.Body.String() != "ok" { + t.Fatalf("status=%d body=%q", w.Code, w.Body.String()) + } + if len(s.metrics.getMetrics()) != 0 { + t.Errorf("want no metrics entries for unsupported path, got %d", len(s.metrics.getMetrics())) + } +} + +func TestServer_HandleUpstream_MetricsSkipsGET(t *testing.T) { + s := upstreamMetricsServer(`{"usage":{}}`) + + w := httptest.NewRecorder() + s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil)) + + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if len(s.metrics.getMetrics()) != 0 { + t.Errorf("want no metrics entries for GET upstream, got %d", len(s.metrics.getMetrics())) + } +} + func TestServer_HandleMetrics_Unavailable(t *testing.T) { s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) diff --git a/internal/server/log.go b/internal/server/log.go index bd776780..570a34fa 100644 --- a/internal/server/log.go +++ b/internal/server/log.go @@ -76,7 +76,7 @@ func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) { case "upstream": return s.upstreamlog, nil default: - if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found { + if _, modelID, _, found := shared.FindModelInPath(s.cfg, "/"+logMonitorID); found { if log, ok := s.local.ProcessLogger(modelID); ok { return log, nil } diff --git a/internal/server/metrics_middleware.go b/internal/server/metrics_middleware.go index 8ba41825..e74eec86 100644 --- a/internal/server/metrics_middleware.go +++ b/internal/server/metrics_middleware.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net/http" + "strings" "github.com/mostlygeek/llama-swap/internal/chain" 
"github.com/mostlygeek/llama-swap/internal/config" @@ -21,8 +22,27 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle return } + // Determine the model-routed endpoint path. Regular routes are + // already meterable; /upstream/<model>/<path> is metered only when + // the remaining path matches a model-dispatched endpoint. + checkPath := r.URL.Path + if strings.HasPrefix(r.URL.Path, "/upstream/") { + var found bool + _, _, checkPath, found = shared.FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream")) + if !found { + next.ServeHTTP(w, r) + return + } + } + + if !isMetricsRecordPath(checkPath) { + next.ServeHTTP(w, r) + return + } + // Resolve the model now so downstream dispatch hits the context - // fast path; FetchContext restores the request body. + // fast path; FetchContext restores the request body for regular + // routes and extracts the model from the URL for /upstream routes. data, err := shared.FetchContext(r, cfg) if err != nil { shared.SendError(w, r, shared.ErrNoModelInContext) @@ -31,7 +51,7 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle // Buffer the request body/headers for capture before dispatch // consumes them.
- cf := captureFieldsFor(r.URL.Path) + cf := captureFieldsFor(checkPath) var reqBody []byte var reqHeaders map[string]string if mm.enableCaptures { diff --git a/internal/server/metrics_test.go b/internal/server/metrics_test.go index 8f061710..e404e4b1 100644 --- a/internal/server/metrics_test.go +++ b/internal/server/metrics_test.go @@ -1,12 +1,15 @@ package server import ( + "io" "net/http" "net/http/httptest" "strings" "testing" "time" + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/shared" "github.com/tidwall/gjson" ) @@ -103,3 +106,40 @@ func TestServer_ParseMetrics_Infill(t *testing.T) { t.Fatalf("tokens = %+v", entry.Tokens) } } + +// TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody verifies that +// an /upstream/<model>/v1/audio/speech request uses the path-specific capture +// mask (headers only) rather than falling back to captureAll. +func TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody(t *testing.T) { + mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5) + cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}} + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "audio/mpeg") + w.WriteHeader(http.StatusOK) + w.Write([]byte("BINARY-AUDIO-DATA")) + }) + handler := CreateMetricsMiddleware(mm, cfg)(inner) + + req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/audio/speech", strings.NewReader(`{"model":"m1"}`)) + handler.ServeHTTP(httptest.NewRecorder(), req) + + entries := mm.getMetrics() + if len(entries) == 0 { + t.Fatal("no metrics recorded") + } + last := entries[len(entries)-1] + if !last.HasCapture { + t.Fatal("expected capture to be stored") + } + cap := mm.getCaptureByID(last.ID) + if cap == nil { + t.Fatal("capture not found") + } + if len(cap.RespBody) != 0 { + t.Errorf("RespBody stored for /upstream audio route (len=%d); want
path-specific mask to skip body", len(cap.RespBody)) + } + if len(cap.RespHeaders) == 0 { + t.Error("RespHeaders not stored; want captureRespHeaders mask") + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 739e0c0d..0d27d89d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -89,6 +89,27 @@ var modelGetRoutes = []string{ "/sdapi/v1/loras", } +// isMetricsRecordPath reports whether path is one of the model-dispatched +// endpoints that the metrics middleware records in the activity log. +func isMetricsRecordPath(path string) bool { + for _, p := range modelPostJSONRoutes { + if p == path { + return true + } + } + for _, p := range modelPostFormRoutes { + if p == path { + return true + } + } + for _, p := range modelGetRoutes { + if p == path { + return true + } + } + return false +} + // BuildInfo carries version metadata surfaced by GET /api/version. type BuildInfo struct { Version string @@ -219,9 +240,11 @@ func (s *Server) routes() { mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload)) mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning)) - // Upstream passthrough. + // Upstream passthrough. Meter only the model-dispatched endpoints that can + // produce token usage/timings. + upstreamChain := apiChain.Append(CreateMetricsMiddleware(s.metrics, s.cfg)) mux.HandleFunc("GET /upstream", handleUpstreamRedirect) - mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream)) + mux.Handle("/upstream/{upstreamPath...}", upstreamChain.ThenFunc(s.handleUpstream)) // API group (API-key protected) consumed by the UI. 
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll)) diff --git a/internal/shared/http.go b/internal/shared/http.go index 30609b7c..ea0d3dfe 100644 --- a/internal/shared/http.go +++ b/internal/shared/http.go @@ -91,16 +91,24 @@ func SendResponse(w http.ResponseWriter, r *http.Request, status int, message st w.Write(resp) } -// FetchContext will attempt to get the model id from the context then -// from the model body. If it extracts the model from the body it will -// store the model in the context for downstream handlers. An error -// will be returned when model can not be fetch from either location. +// FetchContext will attempt to get the model id from the context, then +// from an /upstream/<model> path prefix, then from the request body/query. +// If it extracts the model it will store it in the context for downstream +// handlers. An error will be returned when a model cannot be identified. func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { data, ok := ReadContext(r.Context()) if ok { return data, nil } + if strings.HasPrefix(r.URL.Path, "/upstream/") { + if data, ok := extractUpstreamContext(r, cfg); ok { + *r = *r.WithContext(SetContext(r.Context(), data)) + return data, nil + } + return ReqContextData{}, ErrNoModelInContext + } + if data, err := extractContext(r); err == nil && data.Model != "" { realName, _ := cfg.RealModelName(data.Model) if realName == "" { @@ -117,6 +125,59 @@ func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { return ReqContextData{}, ErrNoModelInContext } +// extractUpstreamContext resolves the model from an /upstream/<model>/... path.
+func extractUpstreamContext(r *http.Request, cfg config.Config) (ReqContextData, bool) { + searchName, realName, _, found := FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream")) + if !found { + return ReqContextData{}, false + } + return ReqContextData{ + Model: searchName, + ModelID: realName, + ApiKey: ExtractAPIKey(r), + Streaming: r.URL.Query().Get("stream") == "true", + SendLoadingState: sendLoadingState(cfg, realName), + Metadata: make(map[string]string), + }, true +} + +// sendLoadingState reports whether the configured model wants loading-state SSEs. +func sendLoadingState(cfg config.Config, modelID string) bool { + if mc, ok := cfg.Models[modelID]; ok { + return mc.SendLoadingState != nil && *mc.SendLoadingState + } + return false +} + +// FindModelInPath walks a slash-separated path, building up segments and +// keeping the longest prefix that matches a configured model. This resolves +// model names that contain slashes (e.g. "author/model"). Returns the matched +// name, its real model ID, the remaining path, and whether a match was found.
+func FindModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) { + parts := strings.Split(strings.TrimSpace(path), "/") + name := "" + + for i, part := range parts { + if part == "" { + continue + } + if name == "" { + name = part + } else { + name = name + "/" + part + } + + if modelID, ok := cfg.RealModelName(name); ok { + searchName = name + realName = modelID + remainingPath = "/" + strings.Join(parts[i+1:], "/") + found = true + } + } + + return +} + func SetContext(ctx context.Context, data ReqContextData) context.Context { return context.WithValue(ctx, ReqContextKey, data) } diff --git a/internal/shared/http_test.go b/internal/shared/http_test.go index c6a88072..e2ab545d 100644 --- a/internal/shared/http_test.go +++ b/internal/shared/http_test.go @@ -11,6 +11,8 @@ import ( "net/url" "strings" "testing" + + "github.com/mostlygeek/llama-swap/internal/config" ) func TestExtractContext_GET(t *testing.T) { @@ -456,3 +458,68 @@ func TestServer_ExtractAPIKey(t *testing.T) { }) } } + +func TestFetchContext_UpstreamPath(t *testing.T) { + cfg := config.Config{ + Models: map[string]config.ModelConfig{ + "m1": {}, + "author/model": {}, + "real": {Aliases: []string{"nick"}}, + }, + } + + cases := []struct { + name string + path string + wantModel string + wantModelID string + wantErr bool + }{ + {"known model", "/upstream/m1/v1/chat/completions", "m1", "m1", false}, + {"model with slash", "/upstream/author/model/v1/chat", "author/model", "author/model", false}, + {"unknown model", "/upstream/nope/v1/chat/completions", "", "", true}, + {"bare model path", "/upstream/m1/", "m1", "m1", false}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, c.path, strings.NewReader(`{}`)) + data, err := FetchContext(r, cfg) + if (err != nil) != c.wantErr { + t.Fatalf("wantErr=%v got err=%v", c.wantErr, err) + } + if c.wantErr { + return + } + if data.Model != c.wantModel { + 
t.Errorf("model = %q, want %q", data.Model, c.wantModel) + } + if data.ModelID != c.wantModelID { + t.Errorf("modelID = %q, want %q", data.ModelID, c.wantModelID) + } + if data.Metadata == nil { + t.Error("metadata map not initialized") + } + }) + } +} + +func TestFetchContext_UpstreamPath_DoesNotReadBody(t *testing.T) { + cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}} + body := `{"model":"should-not-matter"}` + r := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(body)) + + _, err := FetchContext(r, cfg) + if err != nil { + t.Fatalf("FetchContext: %v", err) + } + + // The body should be untouched so the upstream handler can still read it. + got, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + if string(got) != body { + t.Errorf("body was consumed: %q", string(got)) + } +}