Files
llama-swap/internal/server/metrics_test.go
T
Benson Wong a15e47922c proxy: meter /upstream requests via metrics middleware (#858)
Wrap /upstream/{upstreamPath...} in the metrics middleware so activity
log entries are recorded for model-dispatched endpoints accessed through
the upstream passthrough.

- Move findModelInPath to shared.FindModelInPath and reuse it in
handleUpstream, the log monitor lookup, and FetchContext.
- Extend FetchContext to resolve the model from /upstream/<model>/...
paths without consuming the request body.
- Add isMetricsRecordPath to limit recording to the model-dispatched
endpoints that produce token usage/timings.
- Add tests for upstream metrics recording and FetchContext upstream
path resolution.

Fixes #855
2026-06-17 17:38:52 -07:00

146 lines
5.0 KiB
Go

package server
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/shared"
"github.com/tidwall/gjson"
)
// TestServer_ParseMetrics_ChatCompletions feeds parseMetrics a body that only
// carries a "usage" object and checks the input, output, and cached token
// counts on the returned entry.
func TestServer_ParseMetrics_ChatCompletions(t *testing.T) {
	const payload = `{"usage":{"prompt_tokens":12,"completion_tokens":7,"prompt_tokens_details":{"cached_tokens":4}}}`
	doc := gjson.Parse(payload)

	// The payload has no "timings" key, so that lookup yields a non-existent result.
	entry, err := parseMetrics("m", time.Now(), doc.Get("usage"), doc.Get("timings"))
	if err != nil {
		t.Fatalf("parseMetrics: %v", err)
	}

	tok := entry.Tokens
	if !(tok.InputTokens == 12 && tok.OutputTokens == 7 && tok.CachedTokens == 4) {
		t.Fatalf("tokens = %+v", tok)
	}
}
// TestServer_ParseMetrics_Timings feeds parseMetrics a body that only carries
// a "timings" object and checks the token counts, per-second rates, and
// duration on the returned entry.
func TestServer_ParseMetrics_Timings(t *testing.T) {
	const payload = `{"timings":{"prompt_n":20,"predicted_n":50,"prompt_per_second":100.0,"predicted_per_second":40.0,"prompt_ms":200,"predicted_ms":1250,"cache_n":8}}`
	doc := gjson.Parse(payload)
	usage, timings := doc.Get("usage"), doc.Get("timings")

	entry, err := parseMetrics("m", time.Now(), usage, timings)
	if err != nil {
		t.Fatalf("parseMetrics: %v", err)
	}

	tok := entry.Tokens
	if !(tok.InputTokens == 20 && tok.OutputTokens == 50 && tok.CachedTokens == 8) {
		t.Fatalf("tokens = %+v", tok)
	}
	if !(tok.TokensPerSecond == 40.0 && tok.PromptPerSecond == 100.0) {
		t.Fatalf("rates = %+v", tok)
	}
	// The expected 1450 equals prompt_ms (200) plus predicted_ms (1250) from the payload.
	if got := entry.DurationMs; got != 1450 {
		t.Fatalf("DurationMs = %d, want 1450", got)
	}
}
// TestServer_ProcessStreamingResponse passes a three-event "data:" stream to
// processStreamingResponse and checks that the token counts from the event
// carrying "usage" end up on the returned entry.
func TestServer_ProcessStreamingResponse(t *testing.T) {
	const stream = "data: {\"choices\":[{}]}\n\n" +
		"data: {\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":33}}\n\n" +
		"data: [DONE]\n\n"

	entry, err := processStreamingResponse("m", time.Now(), []byte(stream))
	if err != nil {
		t.Fatalf("processStreamingResponse: %v", err)
	}

	tok := entry.Tokens
	if !(tok.InputTokens == 15 && tok.OutputTokens == 33) {
		t.Fatalf("tokens = %+v", tok)
	}
}
// TestServer_ProcessStreamingResponse_NoData expects processStreamingResponse
// to return an error when the stream contains nothing but the [DONE] marker.
func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
	doneOnly := []byte("data: [DONE]\n\n")

	_, err := processStreamingResponse("m", time.Now(), doneOnly)
	if err != nil {
		return
	}
	t.Fatal("expected error for stream with no usage data")
}
// TestMetricsMonitor_RecordMetadata checks that the metadata placed on the
// request context shows up on the single metrics entry produced by record.
func TestMetricsMonitor_RecordMetadata(t *testing.T) {
	mm := newMetricsMonitor(nil, 10, 0)

	// Build a request whose context carries the model ID and metadata.
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"usage":{}}`))
	reqData := shared.ReqContextData{
		ModelID:  "m",
		Metadata: map[string]string{"client": "web", "trace": "abc"},
	}
	req = req.WithContext(shared.SetContext(req.Context(), reqData))

	// Simulate an upstream response written through the body copier.
	copier := newBodyCopier(httptest.NewRecorder())
	copier.WriteHeader(http.StatusOK)
	copier.Write([]byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2}}`))

	mm.record("m", req, copier, 0, nil, nil)

	entries := mm.getMetrics()
	if n := len(entries); n != 1 {
		t.Fatalf("want 1 entry, got %d", n)
	}

	meta := entries[0].Metadata
	if got := meta["client"]; got != "web" {
		t.Errorf("client = %q, want web", got)
	}
	if got := meta["trace"]; got != "abc" {
		t.Errorf("trace = %q, want abc", got)
	}
}
// TestServer_ParseMetrics_Infill checks that parseMetrics yields the expected
// token counts when the timings are taken from the final element of an
// array-shaped body.
func TestServer_ParseMetrics_Infill(t *testing.T) {
	// /infill responses are arrays; timings live in the last element.
	const payload = `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
	doc := gjson.Parse(payload)

	// Start from the top-level lookup and override it with the last element's
	// timings whenever the document has at least one array element.
	timings := doc.Get("timings")
	elems := doc.Array()
	if n := len(elems); n > 0 {
		timings = elems[n-1].Get("timings")
	}

	entry, err := parseMetrics("m", time.Now(), doc.Get("usage"), timings)
	if err != nil {
		t.Fatalf("parseMetrics: %v", err)
	}

	tok := entry.Tokens
	if !(tok.InputTokens == 5 && tok.OutputTokens == 9) {
		t.Fatalf("tokens = %+v", tok)
	}
}
// TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody sends a POST
// to /upstream/<model>/v1/audio/speech through the metrics middleware and
// asserts that the stored capture keeps response headers but no response
// body, i.e. the path-specific capture mask (headers only) is applied rather
// than falling back to captureAll.
func TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody(t *testing.T) {
	mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
	cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}

	// Stand-in upstream that answers with a small binary audio payload.
	upstream := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "audio/mpeg")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("BINARY-AUDIO-DATA"))
	})
	wrapped := CreateMetricsMiddleware(mm, cfg)(upstream)

	req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/audio/speech", strings.NewReader(`{"model":"m1"}`))
	wrapped.ServeHTTP(httptest.NewRecorder(), req)

	entries := mm.getMetrics()
	if len(entries) == 0 {
		t.Fatal("no metrics recorded")
	}

	newest := entries[len(entries)-1]
	if !newest.HasCapture {
		t.Fatal("expected capture to be stored")
	}

	captured := mm.getCaptureByID(newest.ID)
	if captured == nil {
		t.Fatal("capture not found")
	}
	if n := len(captured.RespBody); n != 0 {
		t.Errorf("RespBody stored for /upstream audio route (len=%d); want path-specific mask to skip body", n)
	}
	if len(captured.RespHeaders) == 0 {
		t.Error("RespHeaders not stored; want captureRespHeaders mask")
	}
}