Compare commits

...

1 Commits

Author SHA1 Message Date
Benson Wong a15e47922c proxy: meter /upstream requests via metrics middleware (#858)
Wrap /upstream/{upstreamPath...} in the metrics middleware so activity
log entries are recorded for model-dispatched endpoints accessed through
the upstream passthrough.

- Move findModelInPath to shared.FindModelInPath and reuse it in
handleUpstream, the log monitor lookup, and FetchContext.
- Extend FetchContext to resolve the model from /upstream/<model>/...
paths without consuming the request body.
- Add isMetricsRecordPath to limit recording to the model-dispatched
endpoints that produce token usage/timings.
- Add tests for upstream metrics recording and FetchContext upstream
path resolution.

Fixes #855
2026-06-17 17:38:52 -07:00
8 changed files with 303 additions and 38 deletions
+1 -27
View File
@@ -314,7 +314,7 @@ func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
upstreamPath := r.PathValue("upstreamPath") upstreamPath := r.PathValue("upstreamPath")
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath) searchName, modelID, remainingPath, found := shared.FindModelInPath(s.cfg, "/"+upstreamPath)
if !found { if !found {
shared.SendResponse(w, r, http.StatusNotFound, "model not found") shared.SendResponse(w, r, http.StatusNotFound, "model not found")
return return
@@ -349,29 +349,3 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID) shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
} }
} }
// findModelInPath walks a slash-separated path, building up segments until one
// matches a configured model. This resolves model names that contain slashes
// (e.g. "author/model"). Returns the matched name, its real model ID, the
// remaining path, and whether a match was found.
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
name := ""
for i, part := range parts {
if part == "" {
continue
}
if name == "" {
name = part
} else {
name = name + "/" + part
}
if modelID, ok := cfg.RealModelName(name); ok {
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
}
}
return "", "", "", false
}
+82 -2
View File
@@ -2,11 +2,15 @@ package server
import ( import (
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/shared"
) )
func TestServer_HandleListModels(t *testing.T) { func TestServer_HandleListModels(t *testing.T) {
@@ -78,6 +82,7 @@ func TestServer_HandleListModels_Aliases(t *testing.T) {
func TestServer_FindModelInPath(t *testing.T) { func TestServer_FindModelInPath(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{ cfg := config.Config{Models: map[string]config.ModelConfig{
"author": {},
"author/model": {}, "author/model": {},
"simple": {}, "simple": {},
}} }}
@@ -91,13 +96,14 @@ func TestServer_FindModelInPath(t *testing.T) {
{"/simple/v1/chat", "simple", "/v1/chat", true}, {"/simple/v1/chat", "simple", "/v1/chat", true},
{"/author/model/v1/chat", "author/model", "/v1/chat", true}, {"/author/model/v1/chat", "author/model", "/v1/chat", true},
{"/author/model", "author/model", "/", true}, {"/author/model", "author/model", "/", true},
{"/author/v1/chat", "author", "/v1/chat", true},
{"/missing/v1", "", "", false}, {"/missing/v1", "", "", false},
{"/", "", "", false}, {"/", "", "", false},
} }
for _, c := range cases { for _, c := range cases {
name, _, rem, found := findModelInPath(cfg, c.path) name, _, rem, found := shared.FindModelInPath(cfg, c.path)
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) { if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)", t.Errorf("FindModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound) c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
} }
} }
@@ -133,6 +139,80 @@ func TestServer_HandleUpstream(t *testing.T) {
}) })
} }
func upstreamMetricsServer(response string) *Server {
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
proxylog := logmon.NewWriter(io.Discard)
s := &Server{
cfg: cfg,
muxlog: logmon.NewWriter(io.Discard),
proxylog: proxylog,
upstreamlog: logmon.NewWriter(io.Discard),
inflight: &inflightCounter{},
metrics: newMetricsMonitor(proxylog, 10, 0),
local: newStubRouter([]string{"m1"}, response),
peer: newStubRouter(nil, ""),
}
s.routes()
return s
}
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
s := upstreamMetricsServer(resp)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
s.ServeHTTP(w, req)
if w.Code != http.StatusOK || w.Body.String() != resp {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
entries := s.metrics.getMetrics()
if len(entries) != 1 {
t.Fatalf("want 1 metrics entry, got %d", len(entries))
}
if entries[0].Model != "m1" {
t.Errorf("model = %q, want m1", entries[0].Model)
}
if entries[0].ReqPath != "/v1/chat/completions" {
t.Errorf("req_path = %q, want /v1/chat/completions", entries[0].ReqPath)
}
if entries[0].Tokens.InputTokens != 3 || entries[0].Tokens.OutputTokens != 5 {
t.Errorf("tokens = %+v, want input=3 output=5", entries[0].Tokens)
}
}
func TestServer_HandleUpstream_MetricsSkipsUnsupportedPath(t *testing.T) {
s := upstreamMetricsServer("ok")
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/probe", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
s.ServeHTTP(w, req)
if w.Code != http.StatusOK || w.Body.String() != "ok" {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if len(s.metrics.getMetrics()) != 0 {
t.Errorf("want no metrics entries for unsupported path, got %d", len(s.metrics.getMetrics()))
}
}
func TestServer_HandleUpstream_MetricsSkipsGET(t *testing.T) {
s := upstreamMetricsServer(`{"usage":{}}`)
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
if len(s.metrics.getMetrics()) != 0 {
t.Errorf("want no metrics entries for GET upstream, got %d", len(s.metrics.getMetrics()))
}
}
func TestServer_HandleMetrics_Unavailable(t *testing.T) { func TestServer_HandleMetrics_Unavailable(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
+1 -1
View File
@@ -76,7 +76,7 @@ func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
case "upstream": case "upstream":
return s.upstreamlog, nil return s.upstreamlog, nil
default: default:
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found { if _, modelID, _, found := shared.FindModelInPath(s.cfg, "/"+logMonitorID); found {
if log, ok := s.local.ProcessLogger(modelID); ok { if log, ok := s.local.ProcessLogger(modelID); ok {
return log, nil return log, nil
} }
+22 -2
View File
@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"io" "io"
"net/http" "net/http"
"strings"
"github.com/mostlygeek/llama-swap/internal/chain" "github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
@@ -21,8 +22,27 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
return return
} }
// Determine the model-routed endpoint path. Regular routes are
// already meterable; /upstream/<model>/<path> is metered only when
// the remaining path matches a model-dispatched endpoint.
checkPath := r.URL.Path
if strings.HasPrefix(r.URL.Path, "/upstream/") {
var found bool
_, _, checkPath, found = shared.FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
if !found {
next.ServeHTTP(w, r)
return
}
}
if !isMetricsRecordPath(checkPath) {
next.ServeHTTP(w, r)
return
}
// Resolve the model now so downstream dispatch hits the context // Resolve the model now so downstream dispatch hits the context
// fast path; FetchContext restores the request body. // fast path; FetchContext restores the request body for regular
// routes and extracts the model from the URL for /upstream routes.
data, err := shared.FetchContext(r, cfg) data, err := shared.FetchContext(r, cfg)
if err != nil { if err != nil {
shared.SendError(w, r, shared.ErrNoModelInContext) shared.SendError(w, r, shared.ErrNoModelInContext)
@@ -31,7 +51,7 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
// Buffer the request body/headers for capture before dispatch // Buffer the request body/headers for capture before dispatch
// consumes them. // consumes them.
cf := captureFieldsFor(r.URL.Path) cf := captureFieldsFor(checkPath)
var reqBody []byte var reqBody []byte
var reqHeaders map[string]string var reqHeaders map[string]string
if mm.enableCaptures { if mm.enableCaptures {
+40
View File
@@ -1,12 +1,15 @@
package server package server
import ( import (
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/shared" "github.com/mostlygeek/llama-swap/internal/shared"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -103,3 +106,40 @@ func TestServer_ParseMetrics_Infill(t *testing.T) {
t.Fatalf("tokens = %+v", entry.Tokens) t.Fatalf("tokens = %+v", entry.Tokens)
} }
} }
// TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody verifies that
// an /upstream/<model>/v1/audio/speech request uses the path-specific capture
// mask (headers only) rather than falling back to captureAll.
func TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody(t *testing.T) {
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "audio/mpeg")
w.WriteHeader(http.StatusOK)
w.Write([]byte("BINARY-AUDIO-DATA"))
})
handler := CreateMetricsMiddleware(mm, cfg)(inner)
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/audio/speech", strings.NewReader(`{"model":"m1"}`))
handler.ServeHTTP(httptest.NewRecorder(), req)
entries := mm.getMetrics()
if len(entries) == 0 {
t.Fatal("no metrics recorded")
}
last := entries[len(entries)-1]
if !last.HasCapture {
t.Fatal("expected capture to be stored")
}
cap := mm.getCaptureByID(last.ID)
if cap == nil {
t.Fatal("capture not found")
}
if len(cap.RespBody) != 0 {
t.Errorf("RespBody stored for /upstream audio route (len=%d); want path-specific mask to skip body", len(cap.RespBody))
}
if len(cap.RespHeaders) == 0 {
t.Error("RespHeaders not stored; want captureRespHeaders mask")
}
}
+25 -2
View File
@@ -89,6 +89,27 @@ var modelGetRoutes = []string{
"/sdapi/v1/loras", "/sdapi/v1/loras",
} }
// isMetricsRecordPath reports whether path is one of the model-dispatched
// endpoints that the metrics middleware records in the activity log.
func isMetricsRecordPath(path string) bool {
for _, p := range modelPostJSONRoutes {
if p == path {
return true
}
}
for _, p := range modelPostFormRoutes {
if p == path {
return true
}
}
for _, p := range modelGetRoutes {
if p == path {
return true
}
}
return false
}
// BuildInfo carries version metadata surfaced by GET /api/version. // BuildInfo carries version metadata surfaced by GET /api/version.
type BuildInfo struct { type BuildInfo struct {
Version string Version string
@@ -219,9 +240,11 @@ func (s *Server) routes() {
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload)) mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning)) mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
// Upstream passthrough. // Upstream passthrough. Meter only the model-dispatched endpoints that can
// produce token usage/timings.
upstreamChain := apiChain.Append(CreateMetricsMiddleware(s.metrics, s.cfg))
mux.HandleFunc("GET /upstream", handleUpstreamRedirect) mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream)) mux.Handle("/upstream/{upstreamPath...}", upstreamChain.ThenFunc(s.handleUpstream))
// API group (API-key protected) consumed by the UI. // API group (API-key protected) consumed by the UI.
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll)) mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
+65 -4
View File
@@ -91,16 +91,24 @@ func SendResponse(w http.ResponseWriter, r *http.Request, status int, message st
w.Write(resp) w.Write(resp)
} }
// FetchContext will attempt to get the model id from the context then // FetchContext will attempt to get the model id from the context, then
// from the model body. If it extracts the model from the body it will // from an /upstream/<model> path prefix, then from the request body/query.
// store the model in the context for downstream handlers. An error // If it extracts the model it will store it in the context for downstream
// will be returned when model can not be fetch from either location. // handlers. An error will be returned when a model cannot be identified.
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
data, ok := ReadContext(r.Context()) data, ok := ReadContext(r.Context())
if ok { if ok {
return data, nil return data, nil
} }
if strings.HasPrefix(r.URL.Path, "/upstream/") {
if data, ok := extractUpstreamContext(r, cfg); ok {
*r = *r.WithContext(SetContext(r.Context(), data))
return data, nil
}
return ReqContextData{}, ErrNoModelInContext
}
if data, err := extractContext(r); err == nil && data.Model != "" { if data, err := extractContext(r); err == nil && data.Model != "" {
realName, _ := cfg.RealModelName(data.Model) realName, _ := cfg.RealModelName(data.Model)
if realName == "" { if realName == "" {
@@ -117,6 +125,59 @@ func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
return ReqContextData{}, ErrNoModelInContext return ReqContextData{}, ErrNoModelInContext
} }
// extractUpstreamContext resolves the model from an /upstream/<model>/... path.
func extractUpstreamContext(r *http.Request, cfg config.Config) (ReqContextData, bool) {
searchName, realName, _, found := FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
if !found {
return ReqContextData{}, false
}
return ReqContextData{
Model: searchName,
ModelID: realName,
ApiKey: ExtractAPIKey(r),
Streaming: r.URL.Query().Get("stream") == "true",
SendLoadingState: sendLoadingState(cfg, realName),
Metadata: make(map[string]string),
}, true
}
// sendLoadingState reports whether the configured model wants loading-state SSEs.
func sendLoadingState(cfg config.Config, modelID string) bool {
if mc, ok := cfg.Models[modelID]; ok {
return mc.SendLoadingState != nil && *mc.SendLoadingState
}
return false
}
// FindModelInPath walks a slash-separated path, building up segments until one
// matches a configured model. This resolves model names that contain slashes
// (e.g. "author/model"). Returns the matched name, its real model ID, the
// remaining path, and whether a match was found.
func FindModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
name := ""
for i, part := range parts {
if part == "" {
continue
}
if name == "" {
name = part
} else {
name = name + "/" + part
}
if modelID, ok := cfg.RealModelName(name); ok {
searchName = name
realName = modelID
remainingPath = "/" + strings.Join(parts[i+1:], "/")
found = true
}
}
return
}
func SetContext(ctx context.Context, data ReqContextData) context.Context { func SetContext(ctx context.Context, data ReqContextData) context.Context {
return context.WithValue(ctx, ReqContextKey, data) return context.WithValue(ctx, ReqContextKey, data)
} }
+67
View File
@@ -11,6 +11,8 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"github.com/mostlygeek/llama-swap/internal/config"
) )
func TestExtractContext_GET(t *testing.T) { func TestExtractContext_GET(t *testing.T) {
@@ -456,3 +458,68 @@ func TestServer_ExtractAPIKey(t *testing.T) {
}) })
} }
} }
func TestFetchContext_UpstreamPath(t *testing.T) {
cfg := config.Config{
Models: map[string]config.ModelConfig{
"m1": {},
"author/model": {},
"real": {Aliases: []string{"nick"}},
},
}
cases := []struct {
name string
path string
wantModel string
wantModelID string
wantErr bool
}{
{"known model", "/upstream/m1/v1/chat/completions", "m1", "m1", false},
{"model with slash", "/upstream/author/model/v1/chat", "author/model", "author/model", false},
{"unknown model", "/upstream/nope/v1/chat/completions", "", "", true},
{"bare model path", "/upstream/m1/", "m1", "m1", false},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, c.path, strings.NewReader(`{}`))
data, err := FetchContext(r, cfg)
if (err != nil) != c.wantErr {
t.Fatalf("wantErr=%v got err=%v", c.wantErr, err)
}
if c.wantErr {
return
}
if data.Model != c.wantModel {
t.Errorf("model = %q, want %q", data.Model, c.wantModel)
}
if data.ModelID != c.wantModelID {
t.Errorf("modelID = %q, want %q", data.ModelID, c.wantModelID)
}
if data.Metadata == nil {
t.Error("metadata map not initialized")
}
})
}
}
func TestFetchContext_UpstreamPath_DoesNotReadBody(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
body := `{"model":"should-not-matter"}`
r := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(body))
_, err := FetchContext(r, cfg)
if err != nil {
t.Fatalf("FetchContext: %v", err)
}
// The body should be untouched so the upstream handler can still read it.
got, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("read body: %v", err)
}
if string(got) != body {
t.Errorf("body was consumed: %q", string(got))
}
}