diff --git a/go.mod b/go.mod index d417f215..314ea4be 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.26.1 require ( github.com/billziss-gh/golib v0.2.0 + github.com/fxamacker/cbor/v2 v2.9.1 github.com/gin-gonic/gin v1.10.0 github.com/klauspost/compress v1.18.5 github.com/stretchr/testify v1.9.0 @@ -36,6 +37,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + github.com/x448/float16 v0.8.4 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/net v0.47.0 // indirect diff --git a/go.sum b/go.sum index ca65bc88..18d4093e 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ= +github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -77,6 +79,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= diff --git a/proxy/cache/cache.go b/proxy/cache/cache.go new file mode 100644 index 00000000..f3310058 --- /dev/null +++ b/proxy/cache/cache.go @@ -0,0 +1,102 @@ +package cache + +import ( + "errors" + "sync" +) + +var ( + ErrExceedsMaxSize = errors.New("item exceeds maximum cache size") + ErrNotFound = errors.New("item not found") +) + +type Cache struct { + mu sync.Mutex + items map[int][]byte + order []int + size int + maxSize int +} + +func New(maxBytes int) *Cache { + return &Cache{ + items: make(map[int][]byte), + order: make([]int, 0), + maxSize: maxBytes, + } +} + +func (c *Cache) Add(id int, data []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + dataSize := len(data) + if dataSize > c.maxSize { + return ErrExceedsMaxSize + } + + // If key already exists, remove old entry from size and order + if old, exists := c.items[id]; exists { + c.size -= len(old) + c.removeOrder(id) + } + + // Evict oldest (FIFO) until room available + for c.size+dataSize > c.maxSize && len(c.order) > 0 { + oldestID := c.order[0] + c.order = c.order[1:] + if evicted, exists := c.items[oldestID]; exists { + c.size -= len(evicted) + delete(c.items, oldestID) + } + } + + c.items[id] = data + c.order = append(c.order, id) + c.size += dataSize + return nil +} + +func (c *Cache) removeOrder(id int) { + for i, v := range c.order { + if v == id { + c.order = append(c.order[:i], c.order[i+1:]...) + return + } + } +} + +func (c *Cache) Get(id int) ([]byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + + data, exists := c.items[id] + if !exists { + return nil, ErrNotFound + } + return data, nil +} + +func (c *Cache) Has(id int) bool { + c.mu.Lock() + defer c.mu.Unlock() + + _, exists := c.items[id] + return exists +} + +func (c *Cache) Size() int { + c.mu.Lock() + defer c.mu.Unlock() + + return c.size +} + +func (c *Cache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + c.items = make(map[int][]byte) + c.order = c.order[:0] + c.size = 0 +} diff --git a/proxy/cache/cache_test.go b/proxy/cache/cache_test.go new file mode 100644 index 00000000..8443d79f --- /dev/null +++ b/proxy/cache/cache_test.go @@ -0,0 +1,130 @@ +package cache + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCache_Add(t *testing.T) { + t.Run("adds and retrieves item", func(t *testing.T) { + c := New(1024) + data := []byte("hello") + require.NoError(t, c.Add(1, data)) + + got, err := c.Get(1) + require.NoError(t, err) + assert.Equal(t, data, got) + }) + + t.Run("returns error for oversized item", func(t *testing.T) { + c := New(10) + err := c.Add(1, make([]byte, 20)) + assert.ErrorIs(t, err, ErrExceedsMaxSize) + }) + + t.Run("evicts oldest items to make room", func(t *testing.T) { + c := New(100) + + require.NoError(t, c.Add(1, make([]byte, 40))) + require.NoError(t, c.Add(2, make([]byte, 40))) + // Adding item 3 should evict item 1 + require.NoError(t, c.Add(3, make([]byte, 40))) + + assert.False(t, c.Has(1)) + assert.True(t, c.Has(2)) + assert.True(t, c.Has(3)) + }) + + t.Run("overwrites existing key", func(t *testing.T) { + c := New(100) + require.NoError(t, c.Add(1, []byte("old"))) + require.NoError(t, c.Add(1, []byte("new"))) + + got, err := c.Get(1) + require.NoError(t, err) + assert.Equal(t, []byte("new"), got) + assert.Equal(t, 3, c.Size()) + }) +} + +func TestCache_Get(t *testing.T) { + t.Run("returns ErrNotFound for missing key", func(t *testing.T) { + c := New(100) + _, err := c.Get(99) + assert.ErrorIs(t, err, ErrNotFound) + }) +} + +func TestCache_Has(t *testing.T) { + t.Run("returns true for existing key", func(t *testing.T) { + c := New(100) + require.NoError(t, c.Add(1, []byte("data"))) + assert.True(t, c.Has(1)) + }) + + t.Run("returns false for missing key", func(t *testing.T) { + c := New(100) + assert.False(t, c.Has(1)) + }) +} + +func TestCache_Size(t *testing.T) { + t.Run("tracks byte usage", func(t *testing.T) { + c := New(1000) + assert.Equal(t, 0, c.Size()) + + require.NoError(t, c.Add(1, make([]byte, 100))) + assert.Equal(t, 100, c.Size()) + + require.NoError(t, c.Add(2, make([]byte, 200))) + assert.Equal(t, 300, c.Size()) + }) + + t.Run("updates on eviction", func(t *testing.T) { + c := New(150) + require.NoError(t, c.Add(1, make([]byte, 100))) + require.NoError(t, c.Add(2, make([]byte, 100))) + + // Item 1 should be evicted, size = 100 + assert.Equal(t, 100, c.Size()) + }) +} + +func TestCache_Clear(t *testing.T) { + t.Run("removes all items and resets size", func(t *testing.T) { + c := New(1000) + require.NoError(t, c.Add(1, []byte("a"))) + require.NoError(t, c.Add(2, []byte("b"))) + + c.Clear() + + assert.Equal(t, 0, c.Size()) + assert.False(t, c.Has(1)) + assert.False(t, c.Has(2)) + }) +} + +func TestCache_Concurrent(t *testing.T) { + t.Run("concurrent operations are safe", func(t *testing.T) { + c := New(10000) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 100; j++ { + key := id*100 + j + _ = c.Add(key, []byte("data")) + _, _ = c.Get(key) + _ = c.Has(key) + _ = c.Size() + } + }(i) + } + wg.Wait() + }) +} diff --git a/proxy/events.go b/proxy/events.go index e35ee627..9daec0c9 100644 --- a/proxy/events.go +++ b/proxy/events.go @@ -6,7 +6,7 @@ const ProcessStateChangeEventID = 0x01 const ChatCompletionStatsEventID = 0x02 const ConfigFileChangedEventID = 0x03 const LogDataEventID = 0x04 -const TokenMetricsEventID = 0x05 +const ActivityLogEventID = 0x05 const ModelPreloadedEventID = 0x06 const InFlightRequestsEventID = 0x07 diff --git a/proxy/metrics_monitor.go b/proxy/metrics_monitor.go index 283c1662..94d2e9bd 100644 --- a/proxy/metrics_monitor.go +++ b/proxy/metrics_monitor.go @@ -12,9 +12,11 @@ import ( "sync" "time" + "github.com/fxamacker/cbor/v2" "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/proxy/cache" "github.com/tidwall/gjson" ) @@ -42,37 +44,53 @@ var zstdDecPool = &sync.Pool{ }, } -// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd. -// Returns compressed bytes and the original JSON byte count for logging. +// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd. +// Returns compressed bytes and the original CBOR byte count for logging. func compressCapture(c *ReqRespCapture) ([]byte, int, error) { - jsonBytes, err := json.Marshal(c) + cborBytes, err := cbor.Marshal(c) if err != nil { return nil, 0, fmt.Errorf("marshal capture: %w", err) } - enc := zstdEncPool.Get().(*zstd.Encoder) - defer zstdEncPool.Put(enc) - return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil + zenc := zstdEncPool.Get().(*zstd.Encoder) + defer zstdEncPool.Put(zenc) + return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil } -// decompressCapture decompresses zstd-compressed JSON and returns it. -func decompressCapture(data []byte) ([]byte, error) { +// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture. +func decompressCapture(data []byte) (*ReqRespCapture, error) { dec := zstdDecPool.Get().(*zstd.Decoder) defer zstdDecPool.Put(dec) - return dec.DecodeAll(data, nil) + cborBytes, err := dec.DecodeAll(data, nil) + if err != nil { + return nil, fmt.Errorf("decompress capture: %w", err) + } + var capture ReqRespCapture + if err := cbor.Unmarshal(cborBytes, &capture); err != nil { + return nil, fmt.Errorf("unmarshal capture: %w", err) + } + return &capture, nil } -// TokenMetrics represents parsed token statistics from llama-server logs +// TokenMetrics holds token usage and performance metrics type TokenMetrics struct { - ID int `json:"id"` - Timestamp time.Time `json:"timestamp"` - Model string `json:"model"` - CachedTokens int `json:"cache_tokens"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - PromptPerSecond float64 `json:"prompt_per_second"` - TokensPerSecond float64 `json:"tokens_per_second"` - DurationMs int `json:"duration_ms"` - HasCapture bool `json:"has_capture"` + CachedTokens int `json:"cache_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + PromptPerSecond float64 `json:"prompt_per_second"` + TokensPerSecond float64 `json:"tokens_per_second"` +} + +// ActivityLogEntry represents parsed token statistics from llama-server logs +type ActivityLogEntry struct { + ID int `json:"id"` + Timestamp time.Time `json:"timestamp"` + Model string `json:"model"` + ReqPath string `json:"req_path"` + RespContentType string `json:"resp_content_type"` + RespStatusCode int `json:"resp_status_code"` + Tokens TokenMetrics `json:"tokens"` + DurationMs int `json:"duration_ms"` + HasCapture bool `json:"has_capture"` } type ReqRespCapture struct { @@ -84,48 +102,45 @@ type ReqRespCapture struct { RespBody []byte `json:"resp_body"` } -// TokenMetricsEvent represents a token metrics event -type TokenMetricsEvent struct { - Metrics TokenMetrics +// ActivityLogEvent represents a token metrics event +type ActivityLogEvent struct { + Metrics ActivityLogEntry } -func (e TokenMetricsEvent) Type() uint32 { - return TokenMetricsEventID // defined in events.go +func (e ActivityLogEvent) Type() uint32 { + return ActivityLogEventID // defined in events.go } // metricsMonitor parses llama-server output for token statistics type metricsMonitor struct { mu sync.RWMutex - metrics []TokenMetrics + metrics []ActivityLogEntry maxMetrics int nextID int logger *LogMonitor // capture fields enableCaptures bool - captures map[int][]byte // zstd-compressed JSON of ReqRespCapture - captureOrder []int // track insertion order for FIFO eviction - captureSize int // current total compressed size in bytes - maxCaptureSize int // max bytes for captures (uncompressed) + captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture } // newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the // capture buffer size in megabytes; 0 disables captures. func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor { - return &metricsMonitor{ + mm := &metricsMonitor{ logger: logger, maxMetrics: maxMetrics, enableCaptures: captureBufferMB > 0, - captures: make(map[int][]byte), - captureOrder: make([]int, 0), - captureSize: 0, - maxCaptureSize: captureBufferMB * 1024 * 1024, } + if captureBufferMB > 0 { + mm.captureCache = cache.New(captureBufferMB * 1024 * 1024) + } + return mm } -// addMetrics adds a new metric to the collection and publishes an event. -// Returns the assigned metric ID. -func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int { +// queueMetrics adds a new metric to the collection without emitting an event. +// Returns the assigned metric ID. Call emitMetric after capture setup. +func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int { mp.mu.Lock() defer mp.mu.Unlock() @@ -135,93 +150,75 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int { if len(mp.metrics) > mp.maxMetrics { mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:] } - event.Emit(TokenMetricsEvent{Metrics: metric}) return metric.ID } -// addCapture adds a new capture to the buffer with size-based eviction. -// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize. -func (mp *metricsMonitor) addCapture(capture ReqRespCapture) { +// emitMetric publishes an ActivityLogEvent for the given metric. +func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) { + event.Emit(ActivityLogEvent{Metrics: metric}) +} + +// addCapture compresses and stores a capture in the cache. +// Returns true if the capture was stored, false otherwise. +func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool { if !mp.enableCaptures { - return + return false } compressed, uncompressedBytes, err := compressCapture(&capture) if err != nil { mp.logger.Warnf("failed to compress capture: %v, skipping", err) - return + return false } - captureSize := len(compressed) - if captureSize > mp.maxCaptureSize { - mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize) - return + if err := mp.captureCache.Add(capture.ID, compressed); err != nil { + mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err) + return false } - compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100 - - mp.mu.Lock() - defer mp.mu.Unlock() - - // Evict oldest (FIFO) until room available for the compressed data - for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 { - oldestID := mp.captureOrder[0] - mp.captureOrder = mp.captureOrder[1:] - if evicted, exists := mp.captures[oldestID]; exists { - l := len(evicted) - mp.captureSize -= l - delete(mp.captures, oldestID) - mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l) - } - } - - mp.captures[capture.ID] = compressed - mp.captureOrder = append(mp.captureOrder, capture.ID) - mp.captureSize += captureSize - + compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100 mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio) + return true } // getCompressedBytes returns the raw compressed bytes for a capture by ID. func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) { - mp.mu.RLock() - defer mp.mu.RUnlock() - - data, exists := mp.captures[id] - return data, exists + if mp.captureCache == nil { + return nil, false + } + data, err := mp.captureCache.Get(id) + if err != nil { + return nil, false + } + return data, true } -// getCaptureByID returns decompressed capture bytes if found and decompress=true. -// If decompress=false, returns the raw zstd-compressed bytes. -// Returns nil if the capture is not found. -func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte { - mp.mu.RLock() - defer mp.mu.RUnlock() - - data, exists := mp.captures[id] +// getCaptureByID decompresses and unmarshals a capture by ID. +// Returns nil if the capture is not found or decompression fails. +func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture { + if mp.captureCache == nil { + return nil + } + data, exists := mp.getCompressedBytes(id) if !exists { return nil } - if !decompress { - return data - } - - decompressed, err := decompressCapture(data) + capture, err := decompressCapture(data) if err != nil { mp.logger.Warnf("failed to decompress capture %d: %v", id, err) return nil } - return decompressed + return capture } // getMetrics returns a copy of the current metrics -func (mp *metricsMonitor) getMetrics() []TokenMetrics { +func (mp *metricsMonitor) getMetrics() []ActivityLogEntry { mp.mu.RLock() defer mp.mu.RUnlock() - result := make([]TokenMetrics, len(mp.metrics)) + result := make([]ActivityLogEntry, len(mp.metrics)) copy(result, mp.metrics) return result } @@ -230,22 +227,52 @@ func (mp *metricsMonitor) getMetrics() []TokenMetrics { func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) { mp.mu.RLock() defer mp.mu.RUnlock() - return json.Marshal(mp.metrics) + + if mp.captureCache == nil { + return json.Marshal(mp.metrics) + } + + // Make a copy with up-to-date has_capture from cache + result := make([]ActivityLogEntry, len(mp.metrics)) + for i, m := range mp.metrics { + m.HasCapture = mp.captureCache.Has(m.ID) + result[i] = m + } + return json.Marshal(result) } -// wrapHandler wraps the proxy handler to extract token metrics +// Capture field flags for controlling what is saved in ReqRespCapture. +type captureFields uint + +const ( + captureNone captureFields = 1 << iota + captureReqHeaders + captureReqBody + captureRespHeaders + captureRespBody +) + +const ( + captureReqAll = captureReqHeaders | captureReqBody + captureRespAll = captureRespHeaders | captureRespBody + captureAll = captureReqAll | captureRespAll +) + +// wrapHandler wraps the proxy handler to extract token metrics. +// captureFields controls what is saved in the ReqRespCapture using bitwise flags. // if wrapHandler returns an error it is safe to assume that no // data was sent to the client func (mp *metricsMonitor) wrapHandler( modelID string, writer gin.ResponseWriter, request *http.Request, + captureFields captureFields, next func(modelID string, w http.ResponseWriter, r *http.Request) error, ) error { // Capture request body and headers if captures enabled var reqBody []byte var reqHeaders map[string]string - if mp.enableCaptures { + if mp.enableCaptures && (captureFields&captureReqBody) != 0 { if request.Body != nil { var err error reqBody, err = io.ReadAll(request.Body) @@ -255,6 +282,8 @@ func (mp *metricsMonitor) wrapHandler( request.Body.Close() request.Body = io.NopCloser(bytes.NewBuffer(reqBody)) } + } + if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 { reqHeaders = make(map[string]string) for key, values := range request.Header { if len(values) > 0 { @@ -278,22 +307,28 @@ func (mp *metricsMonitor) wrapHandler( // after this point we have to assume that data was sent to the client // and we can only log errors but not send them to clients - if recorder.Status() != http.StatusOK { - mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path) - return nil + // Initialize default metrics - recorded for every request + tm := ActivityLogEntry{ + Timestamp: time.Now(), + Model: modelID, + ReqPath: request.URL.Path, + RespContentType: recorder.Header().Get("Content-Type"), + RespStatusCode: recorder.Status(), + DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), } - // Initialize default metrics - these will always be recorded - tm := TokenMetrics{ - Timestamp: time.Now(), - Model: modelID, - DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), + if recorder.Status() != http.StatusOK { + mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path) + tm.ID = mp.queueMetrics(tm) + mp.emitMetric(tm) + return nil } body := recorder.body.Bytes() if len(body) == 0 { mp.logger.Warn("metrics: empty body, recording minimal metrics") - mp.addMetrics(tm) + tm.ID = mp.queueMetrics(tm) + mp.emitMetric(tm) return nil } @@ -303,7 +338,8 @@ func (mp *metricsMonitor) wrapHandler( body, err = decompressBody(body, encoding) if err != nil { mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path) - mp.addMetrics(tm) + tm.ID = mp.queueMetrics(tm) + mp.emitMetric(tm) return nil } } @@ -311,7 +347,8 @@ func (mp *metricsMonitor) wrapHandler( if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path) } else { - tm = parsed + tm.Tokens = parsed.Tokens + tm.DurationMs = parsed.DurationMs } } else { if gjson.ValidBytes(body) { @@ -331,7 +368,8 @@ func (mp *metricsMonitor) wrapHandler( if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil { mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path) } else { - tm = parsedMetrics + tm.Tokens = parsedMetrics.Tokens + tm.DurationMs = parsedMetrics.DurationMs } } } else { @@ -342,39 +380,50 @@ func (mp *metricsMonitor) wrapHandler( // Build capture if enabled and determine if it will be stored var capture *ReqRespCapture if mp.enableCaptures { - respHeaders := make(map[string]string) - for key, values := range recorder.Header() { - if len(values) > 0 { - respHeaders[key] = values[0] + var respHeaders map[string]string + var respBody []byte + if (captureFields & captureRespHeaders) != 0 { + respHeaders = make(map[string]string) + for key, values := range recorder.Header() { + if len(values) > 0 { + respHeaders[key] = values[0] + } } + redactHeaders(respHeaders) + delete(respHeaders, "Content-Encoding") + } + if (captureFields & captureRespBody) != 0 { + respBody = body } - redactHeaders(respHeaders) - delete(respHeaders, "Content-Encoding") capture = &ReqRespCapture{ ReqPath: request.URL.Path, ReqHeaders: reqHeaders, ReqBody: reqBody, RespHeaders: respHeaders, - RespBody: body, - } - compressed, _, err := compressCapture(capture) - if err == nil && len(compressed) <= mp.maxCaptureSize { - tm.HasCapture = true + RespBody: respBody, } } - metricID := mp.addMetrics(tm) + metricID := mp.queueMetrics(tm) + tm.ID = metricID // Store capture if enabled if capture != nil { capture.ID = metricID - mp.addCapture(*capture) + if mp.addCapture(*capture) { + tm.HasCapture = true + mp.mu.Lock() + mp.metrics[len(mp.metrics)-1].HasCapture = true + mp.mu.Unlock() + } } + mp.emitMetric(tm) + return nil } -func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) { +func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) { // Iterate **backwards** through the body looking for the data payload with // usage data. This avoids allocating a slice of all lines via bytes.Split. @@ -428,10 +477,10 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok } } - return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream") + return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream") } -func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) { +func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) { wallDurationMs := int(time.Since(start).Milliseconds()) // default values @@ -481,15 +530,17 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) } } - return TokenMetrics{ - Timestamp: time.Now(), - Model: modelID, - CachedTokens: cachedTokens, - InputTokens: inputTokens, - OutputTokens: outputTokens, - PromptPerSecond: promptPerSecond, - TokensPerSecond: tokensPerSecond, - DurationMs: durationMs, + return ActivityLogEntry{ + Timestamp: time.Now(), + Model: modelID, + Tokens: TokenMetrics{ + CachedTokens: cachedTokens, + InputTokens: inputTokens, + OutputTokens: outputTokens, + PromptPerSecond: promptPerSecond, + TokensPerSecond: tokensPerSecond, + }, + DurationMs: durationMs, }, nil } diff --git a/proxy/metrics_monitor_test.go b/proxy/metrics_monitor_test.go index 48372d9e..9fb737a5 100644 --- a/proxy/metrics_monitor_test.go +++ b/proxy/metrics_monitor_test.go @@ -12,8 +12,10 @@ import ( "testing" "time" + "github.com/fxamacker/cbor/v2" "github.com/gin-gonic/gin" "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/proxy/cache" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) @@ -22,27 +24,29 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) { t.Run("adds metrics and assigns ID", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 0) - metric := TokenMetrics{ - Model: "test-model", - InputTokens: 100, - OutputTokens: 50, + metric := ActivityLogEntry{ + Model: "test-model", + Tokens: TokenMetrics{ + InputTokens: 100, + OutputTokens: 50, + }, } - mm.addMetrics(metric) + mm.queueMetrics(metric) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, 0, metrics[0].ID) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 100, metrics[0].InputTokens) - assert.Equal(t, 50, metrics[0].OutputTokens) + assert.Equal(t, 100, metrics[0].Tokens.InputTokens) + assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) }) t.Run("increments ID for each metric", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 0) for i := 0; i < 5; i++ { - mm.addMetrics(TokenMetrics{Model: "model"}) + mm.queueMetrics(ActivityLogEntry{Model: "model"}) } metrics := mm.getMetrics() @@ -57,9 +61,11 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) { // Add 5 metrics for i := 0; i < 5; i++ { - mm.addMetrics(TokenMetrics{ - Model: "model", - InputTokens: i, + mm.queueMetrics(ActivityLogEntry{ + Model: "model", + Tokens: TokenMetrics{ + InputTokens: i, + }, }) } @@ -72,29 +78,32 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) { assert.Equal(t, 4, metrics[2].ID) }) - t.Run("emits TokenMetricsEvent", func(t *testing.T) { + t.Run("emits ActivityLogEvent", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 0) - receivedEvent := make(chan TokenMetricsEvent, 1) - cancel := event.On(func(e TokenMetricsEvent) { + receivedEvent := make(chan ActivityLogEvent, 1) + cancel := event.On(func(e ActivityLogEvent) { receivedEvent <- e }) defer cancel() - metric := TokenMetrics{ - Model: "test-model", - InputTokens: 100, - OutputTokens: 50, + metric := ActivityLogEntry{ + Model: "test-model", + Tokens: TokenMetrics{ + InputTokens: 100, + OutputTokens: 50, + }, } - mm.addMetrics(metric) + mm.queueMetrics(metric) + mm.emitMetric(metric) select { case evt := <-receivedEvent: assert.Equal(t, 0, evt.Metrics.ID) assert.Equal(t, "test-model", evt.Metrics.Model) - assert.Equal(t, 100, evt.Metrics.InputTokens) - assert.Equal(t, 50, evt.Metrics.OutputTokens) + assert.Equal(t, 100, evt.Metrics.Tokens.InputTokens) + assert.Equal(t, 50, evt.Metrics.Tokens.OutputTokens) case <-time.After(1 * time.Second): t.Fatal("timeout waiting for event") } @@ -111,8 +120,8 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) { t.Run("returns copy of metrics", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 0) - mm.addMetrics(TokenMetrics{Model: "model1"}) - mm.addMetrics(TokenMetrics{Model: "model2"}) + mm.queueMetrics(ActivityLogEntry{Model: "model1"}) + mm.queueMetrics(ActivityLogEntry{Model: "model2"}) metrics1 := mm.getMetrics() metrics2 := mm.getMetrics() @@ -135,7 +144,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, jsonData) - var metrics []TokenMetrics + var metrics []ActivityLogEntry err = json.Unmarshal(jsonData, &metrics) assert.NoError(t, err) assert.Equal(t, 0, len(metrics)) @@ -143,23 +152,27 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) { t.Run("returns valid JSON with metrics", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 0) - mm.addMetrics(TokenMetrics{ - Model: "model1", - InputTokens: 100, - OutputTokens: 50, - TokensPerSecond: 25.5, + mm.queueMetrics(ActivityLogEntry{ + Model: "model1", + Tokens: TokenMetrics{ + InputTokens: 100, + OutputTokens: 50, + TokensPerSecond: 25.5, + }, }) - mm.addMetrics(TokenMetrics{ - Model: "model2", - InputTokens: 200, - OutputTokens: 100, - TokensPerSecond: 30.0, + mm.queueMetrics(ActivityLogEntry{ + Model: "model2", + Tokens: TokenMetrics{ + InputTokens: 200, + OutputTokens: 100, + TokensPerSecond: 30.0, + }, }) jsonData, err := mm.getMetricsJSON() assert.NoError(t, err) - var metrics []TokenMetrics + var metrics []ActivityLogEntry err = json.Unmarshal(jsonData, &metrics) assert.NoError(t, err) assert.Equal(t, 2, len(metrics)) @@ -190,14 +203,14 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 100, metrics[0].InputTokens) - assert.Equal(t, 50, metrics[0].OutputTokens) + assert.Equal(t, 100, metrics[0].Tokens.InputTokens) + assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) }) t.Run("successful request with timings data", func(t *testing.T) { @@ -226,17 +239,17 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 100, metrics[0].InputTokens) - assert.Equal(t, 50, metrics[0].OutputTokens) - assert.Equal(t, 20, metrics[0].CachedTokens) - assert.Equal(t, 150.5, metrics[0].PromptPerSecond) - assert.Equal(t, 25.5, metrics[0].TokensPerSecond) + assert.Equal(t, 100, metrics[0].Tokens.InputTokens) + assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) + assert.Equal(t, 20, metrics[0].Tokens.CachedTokens) + assert.Equal(t, 150.5, metrics[0].Tokens.PromptPerSecond) + assert.Equal(t, 25.5, metrics[0].Tokens.TokensPerSecond) assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500 }) @@ -265,18 +278,18 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) // When timings data is present, it takes precedence - assert.Equal(t, 10, metrics[0].InputTokens) - assert.Equal(t, 20, metrics[0].OutputTokens) + assert.Equal(t, 10, metrics[0].Tokens.InputTokens) + assert.Equal(t, 20, metrics[0].Tokens.OutputTokens) }) - t.Run("non-OK status code does not record metrics", func(t *testing.T) { + t.Run("non-OK status code records partial metrics", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 0) nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { @@ -289,11 +302,16 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() - assert.Equal(t, 0, len(metrics)) + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, "/test", metrics[0].ReqPath) + assert.Equal(t, http.StatusBadRequest, metrics[0].RespStatusCode) + assert.Equal(t, 0, metrics[0].Tokens.InputTokens) + assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) }) t.Run("empty response body records minimal metrics", func(t *testing.T) { @@ -308,14 +326,14 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].InputTokens) - assert.Equal(t, 0, metrics[0].OutputTokens) + assert.Equal(t, 0, metrics[0].Tokens.InputTokens) + assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) }) t.Run("invalid JSON records minimal metrics", func(t *testing.T) { @@ -332,14 +350,14 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) // Errors after response is sent are logged, not returned metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].InputTokens) - assert.Equal(t, 0, metrics[0].OutputTokens) + assert.Equal(t, 0, metrics[0].Tokens.InputTokens) + assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) }) t.Run("next handler error is propagated", func(t *testing.T) { @@ -354,7 +372,7 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.Equal(t, expectedErr, err) metrics := mm.getMetrics() @@ -377,14 +395,14 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].InputTokens) - assert.Equal(t, 0, metrics[0].OutputTokens) + assert.Equal(t, 0, metrics[0].Tokens.InputTokens) + assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) }) t.Run("infill request extracts timings from last array element", func(t *testing.T) { @@ -416,17 +434,17 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 150, metrics[0].InputTokens) - assert.Equal(t, 75, metrics[0].OutputTokens) - assert.Equal(t, 30, metrics[0].CachedTokens) - assert.Equal(t, 200.5, metrics[0].PromptPerSecond) - assert.Equal(t, 35.5, metrics[0].TokensPerSecond) + assert.Equal(t, 150, metrics[0].Tokens.InputTokens) + assert.Equal(t, 75, metrics[0].Tokens.OutputTokens) + assert.Equal(t, 30, metrics[0].Tokens.CachedTokens) + assert.Equal(t, 200.5, metrics[0].Tokens.PromptPerSecond) + assert.Equal(t, 35.5, metrics[0].Tokens.TokensPerSecond) assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800 }) @@ -446,14 +464,14 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].InputTokens) - assert.Equal(t, 0, metrics[0].OutputTokens) + assert.Equal(t, 0, metrics[0].Tokens.InputTokens) + assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) }) } @@ -507,7 +525,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) { } func TestMetricsMonitor_Concurrent(t *testing.T) { - t.Run("concurrent addMetrics is safe", func(t *testing.T) { + t.Run("concurrent queueMetrics is safe", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 1000, 0) var wg sync.WaitGroup @@ -519,10 +537,12 @@ func TestMetricsMonitor_Concurrent(t *testing.T) { go func(id int) { defer wg.Done() for j := 0; j < metricsPerGoroutine; j++ { - mm.addMetrics(TokenMetrics{ - Model: "test-model", - InputTokens: id*1000 + j, - OutputTokens: j, + mm.queueMetrics(ActivityLogEntry{ + Model: "test-model", + Tokens: TokenMetrics{ + InputTokens: id*1000 + j, + OutputTokens: j, + }, }) } }(i) @@ -542,7 +562,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) { // Writer goroutine go func() { for i := 0; i < 50; i++ { - mm.addMetrics(TokenMetrics{Model: "test-model"}) + mm.queueMetrics(ActivityLogEntry{Model: "test-model"}) time.Sleep(1 * time.Millisecond) } done <- true @@ -586,10 +606,10 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) { metrics, err := parseMetrics("test-model", start, usage, timings) assert.NoError(t, err) - assert.Equal(t, 5, metrics.InputTokens) - assert.Equal(t, 1, metrics.OutputTokens) - assert.Equal(t, 10.0, metrics.PromptPerSecond) - assert.Equal(t, 2.0, metrics.TokensPerSecond) + assert.Equal(t, 5, metrics.Tokens.InputTokens) + assert.Equal(t, 1, metrics.Tokens.OutputTokens) + assert.Equal(t, 10.0, metrics.Tokens.PromptPerSecond) + assert.Equal(t, 2.0, metrics.Tokens.TokensPerSecond) assert.GreaterOrEqual(t, metrics.DurationMs, 5000) }) @@ -623,14 +643,14 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) // Should use timings values, not usage values - assert.Equal(t, 100, metrics[0].InputTokens) - assert.Equal(t, 50, metrics[0].OutputTokens) + assert.Equal(t, 100, metrics[0].Tokens.InputTokens) + assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) }) t.Run("handles missing cache_n in timings", func(t *testing.T) { @@ -658,12 +678,12 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) - assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present + assert.Equal(t, -1, metrics[0].Tokens.CachedTokens) // Default value when not present }) } @@ -693,13 +713,13 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) - assert.Equal(t, 100, metrics[0].InputTokens) - assert.Equal(t, 50, metrics[0].OutputTokens) + assert.Equal(t, 100, metrics[0].Tokens.InputTokens) + assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) }) t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) { @@ -722,14 +742,14 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].InputTokens) - assert.Equal(t, 0, metrics[0].OutputTokens) + assert.Equal(t, 0, metrics[0].Tokens.InputTokens) + assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) }) t.Run("v1/responses format with nested response.usage", func(t *testing.T) { @@ -751,14 +771,14 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 17, metrics[0].InputTokens) - assert.Equal(t, 23, metrics[0].OutputTokens) + assert.Equal(t, 17, metrics[0].Tokens.InputTokens) + assert.Equal(t, 23, metrics[0].Tokens.OutputTokens) }) t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) { @@ -777,14 +797,14 @@ data: [DONE] rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].InputTokens) - assert.Equal(t, 0, metrics[0].OutputTokens) + assert.Equal(t, 0, metrics[0].Tokens.InputTokens) + assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) }) } @@ -792,20 +812,22 @@ data: [DONE] func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) { mm := newMetricsMonitor(testLogger, 1000, 0) - metric := TokenMetrics{ - Model: "test-model", - CachedTokens: 100, - InputTokens: 500, - OutputTokens: 250, - PromptPerSecond: 1200.5, - TokensPerSecond: 45.8, - DurationMs: 5000, - Timestamp: time.Now(), + metric := ActivityLogEntry{ + Model: "test-model", + Tokens: TokenMetrics{ + CachedTokens: 100, + InputTokens: 500, + OutputTokens: 250, + PromptPerSecond: 1200.5, + TokensPerSecond: 45.8, + }, + DurationMs: 5000, + Timestamp: time.Now(), } b.ResetTimer() for i := 0; i < b.N; i++ { - mm.addMetrics(metric) + mm.queueMetrics(metric) } } @@ -813,20 +835,22 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) { // Test performance with a smaller buffer where wrapping occurs more frequently mm := newMetricsMonitor(testLogger, 100, 0) - metric := TokenMetrics{ - Model: "test-model", - CachedTokens: 100, - InputTokens: 500, - OutputTokens: 250, - PromptPerSecond: 1200.5, - TokensPerSecond: 45.8, - DurationMs: 5000, - Timestamp: time.Now(), + metric := ActivityLogEntry{ + Model: "test-model", + Tokens: TokenMetrics{ + CachedTokens: 100, + InputTokens: 500, + OutputTokens: 250, + PromptPerSecond: 1200.5, + TokensPerSecond: 45.8, + }, + DurationMs: 5000, + Timestamp: time.Now(), } b.ResetTimer() for i := 0; i < b.N; i++ { - mm.addMetrics(metric) + mm.queueMetrics(metric) } } @@ -855,14 +879,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 100, metrics[0].InputTokens) - assert.Equal(t, 50, metrics[0].OutputTokens) + assert.Equal(t, 100, metrics[0].Tokens.InputTokens) + assert.Equal(t, 50, metrics[0].Tokens.OutputTokens) }) t.Run("deflate encoded response", func(t *testing.T) { @@ -889,14 +913,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 200, metrics[0].InputTokens) - assert.Equal(t, 75, metrics[0].OutputTokens) + assert.Equal(t, 200, metrics[0].Tokens.InputTokens) + assert.Equal(t, 75, metrics[0].Tokens.OutputTokens) }) t.Run("invalid gzip data records minimal metrics", func(t *testing.T) { @@ -917,14 +941,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) // Should not return error, just log warning metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) assert.Equal(t, "test-model", metrics[0].Model) - assert.Equal(t, 0, metrics[0].InputTokens) - assert.Equal(t, 0, metrics[0].OutputTokens) + assert.Equal(t, 0, metrics[0].Tokens.InputTokens) + assert.Equal(t, 0, metrics[0].Tokens.OutputTokens) }) t.Run("unknown encoding treated as uncompressed", func(t *testing.T) { @@ -944,13 +968,13 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) metrics := mm.getMetrics() assert.Equal(t, 1, len(metrics)) - assert.Equal(t, 300, metrics[0].InputTokens) - assert.Equal(t, 100, metrics[0].OutputTokens) + assert.Equal(t, 300, metrics[0].Tokens.InputTokens) + assert.Equal(t, 100, metrics[0].Tokens.OutputTokens) }) } @@ -989,7 +1013,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) { mm.addCapture(capture) // Should not store capture - assert.Nil(t, mm.getCaptureByID(0, false)) + assert.Nil(t, mm.getCaptureByID(0)) }) t.Run("adds capture when enabled", func(t *testing.T) { @@ -1002,22 +1026,18 @@ func TestMetricsMonitor_AddCapture(t *testing.T) { } mm.addCapture(capture) - retrieved := mm.getCaptureByID(0, true) - assert.NotNil(t, retrieved) - - var decoded ReqRespCapture - err := json.Unmarshal(retrieved, &decoded) - assert.NoError(t, err) - assert.Equal(t, 0, decoded.ID) - assert.Equal(t, []byte("test request"), decoded.ReqBody) - assert.Equal(t, []byte("test response"), decoded.RespBody) + captured := mm.getCaptureByID(0) + assert.NotNil(t, captured) + assert.Equal(t, 0, captured.ID) + assert.Equal(t, []byte("test request"), captured.ReqBody) + assert.Equal(t, []byte("test response"), captured.RespBody) }) t.Run("evicts oldest when exceeding max size", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 5) // Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes. // 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit. - mm.maxCaptureSize = 450 + mm.captureCache = cache.New(450) // Use random-looking data that doesn't compress well with zstd rng := rand.New(rand.NewSource(42)) @@ -1033,16 +1053,14 @@ func TestMetricsMonitor_AddCapture(t *testing.T) { // Adding capture3 should evict capture1 mm.addCapture(capture3) - assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted") - retrieved := mm.getCaptureByID(1, true) - assert.NotNil(t, retrieved, "capture 1 should exist") - retrieved = mm.getCaptureByID(2, true) - assert.NotNil(t, retrieved, "capture 2 should exist") + assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted") + assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist") + assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist") }) t.Run("skips capture larger than max size", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 5) - mm.maxCaptureSize = 100 + mm.captureCache = cache.New(100) // Use random data that doesn't compress well to create an oversized capture rng := rand.New(rand.NewSource(99)) @@ -1050,7 +1068,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) { rng.Read(largeCapture.ReqBody) mm.addCapture(largeCapture) - assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored") + assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored") }) } @@ -1058,7 +1076,7 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) { t.Run("returns nil for non-existent ID", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 5) - assert.Nil(t, mm.getCaptureByID(999, false)) + assert.Nil(t, mm.getCaptureByID(999)) }) t.Run("returns decompressed capture by ID", func(t *testing.T) { @@ -1071,18 +1089,14 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) { } mm.addCapture(capture) - retrieved := mm.getCaptureByID(42, true) - assert.NotNil(t, retrieved) - - var decoded ReqRespCapture - err := json.Unmarshal(retrieved, &decoded) - assert.NoError(t, err) - assert.Equal(t, 42, decoded.ID) - assert.Equal(t, []byte("test request"), decoded.ReqBody) - assert.Equal(t, []byte("test response"), decoded.RespBody) + captured := mm.getCaptureByID(42) + assert.NotNil(t, captured) + assert.Equal(t, 42, captured.ID) + assert.Equal(t, []byte("test request"), captured.ReqBody) + assert.Equal(t, []byte("test response"), captured.RespBody) }) - t.Run("returns compressed bytes when decompress=false", func(t *testing.T) { + t.Run("stores data as compressed bytes", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 5) capture := ReqRespCapture{ @@ -1092,10 +1106,12 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) { } mm.addCapture(capture) - compressed := mm.getCaptureByID(42, false) + compressed, exists := mm.getCompressedBytes(42) + assert.True(t, exists) assert.NotNil(t, compressed) - // Compressed data should not be valid JSON (it's zstd-compressed) - assert.False(t, gjson.ValidBytes(compressed)) + // Compressed data should not be valid CBOR (it's zstd-compressed) + var decoded ReqRespCapture + assert.Error(t, cbor.Unmarshal(compressed, &decoded)) }) } @@ -1164,7 +1180,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) // Check metric was recorded @@ -1173,12 +1189,8 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) { metricID := metrics[0].ID // Check capture was stored with same ID (decompressed) - captureData := mm.getCaptureByID(metricID, true) - assert.NotNil(t, captureData) - - var capture ReqRespCapture - err = json.Unmarshal(captureData, &capture) - assert.NoError(t, err) + capture := mm.getCaptureByID(metricID) + assert.NotNil(t, capture) assert.Equal(t, metricID, capture.ID) assert.Equal(t, []byte(requestBody), capture.ReqBody) assert.Equal(t, []byte(responseBody), capture.RespBody) @@ -1206,7 +1218,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) { rec := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(rec) - err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler) assert.NoError(t, err) // Metrics should still be recorded @@ -1214,7 +1226,168 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) { assert.Equal(t, 1, len(metrics)) // But no capture - capture := mm.getCaptureByID(metrics[0].ID, false) - assert.Nil(t, capture) + assert.Nil(t, mm.getCaptureByID(metrics[0].ID)) + }) +} + +func TestMetricsMonitor_WrapHandler_PartialCaptures(t *testing.T) { + requestBody := `{"model": "test"}` + responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Custom", "header-value") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + t.Run("only request headers", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 100) + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer secret") + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders, nextHandler) + assert.NoError(t, err) + + capture := mm.getCaptureByID(mm.getMetrics()[0].ID) + assert.NotNil(t, capture) + assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"]) + assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) + assert.Nil(t, capture.ReqBody) + assert.Nil(t, capture.RespHeaders) + assert.Nil(t, capture.RespBody) + }) + + t.Run("only request body", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 100) + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqBody, nextHandler) + assert.NoError(t, err) + + capture := mm.getCaptureByID(mm.getMetrics()[0].ID) + assert.NotNil(t, capture) + assert.Nil(t, capture.ReqHeaders) + assert.Equal(t, []byte(requestBody), capture.ReqBody) + assert.Nil(t, capture.RespHeaders) + assert.Nil(t, capture.RespBody) + }) + + t.Run("only response headers", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 100) + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespHeaders, nextHandler) + assert.NoError(t, err) + + capture := mm.getCaptureByID(mm.getMetrics()[0].ID) + assert.NotNil(t, capture) + assert.Nil(t, capture.ReqHeaders) + assert.Nil(t, capture.ReqBody) + assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"]) + assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"]) + assert.Nil(t, capture.RespBody) + }) + + t.Run("only response body", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 100) + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespBody, nextHandler) + assert.NoError(t, err) + + capture := mm.getCaptureByID(mm.getMetrics()[0].ID) + assert.NotNil(t, capture) + assert.Nil(t, capture.ReqHeaders) + assert.Nil(t, capture.ReqBody) + assert.Nil(t, capture.RespHeaders) + assert.Equal(t, []byte(responseBody), capture.RespBody) + }) + + t.Run("captureReqAll", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 100) + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer secret") + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqAll, nextHandler) + assert.NoError(t, err) + + capture := mm.getCaptureByID(mm.getMetrics()[0].ID) + assert.NotNil(t, capture) + assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"]) + assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) + assert.Equal(t, []byte(requestBody), capture.ReqBody) + assert.Nil(t, capture.RespHeaders) + assert.Nil(t, capture.RespBody) + }) + + t.Run("captureRespAll", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 100) + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespAll, nextHandler) + assert.NoError(t, err) + + capture := mm.getCaptureByID(mm.getMetrics()[0].ID) + assert.NotNil(t, capture) + assert.Nil(t, capture.ReqHeaders) + assert.Nil(t, capture.ReqBody) + assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"]) + assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"]) + assert.Equal(t, []byte(responseBody), capture.RespBody) + }) + + t.Run("no flags", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 100) + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureFields(0), nextHandler) + assert.NoError(t, err) + + capture := mm.getCaptureByID(mm.getMetrics()[0].ID) + assert.NotNil(t, capture) + assert.Nil(t, capture.ReqHeaders) + assert.Nil(t, capture.ReqBody) + assert.Nil(t, capture.RespHeaders) + assert.Nil(t, capture.RespBody) + }) + + t.Run("mixed flags req headers and resp body", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 100) + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer secret") + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders|captureRespBody, nextHandler) + assert.NoError(t, err) + + capture := mm.getCaptureByID(mm.getMetrics()[0].ID) + assert.NotNil(t, capture) + assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"]) + assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) + assert.Nil(t, capture.ReqBody) + assert.Nil(t, capture.RespHeaders) + assert.Equal(t, []byte(responseBody), capture.RespBody) }) } diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index ee1d3484..6b27508d 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -332,41 +332,77 @@ func (pm *ProxyManager) setupGinEngine() { // Set up routes using the Gin engine // Protected routes use pm.apiKeyAuth() middleware - pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) - pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + llmHandler := pm.mkProxyJSONHandler(captureAll) + pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) + pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) // Support legacy /v1/completions api, see issue #12 - pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) // Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570) - pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) // Support anthropic count_tokens API (Also added in the above PR) - pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) // Support embeddings and reranking - pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) // llama-server's /reranking endpoint + aliases - pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) - pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) - pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) - pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) + pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) + pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) + pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) // llama-server's /infill endpoint for code infilling - pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) // llama-server's /completion endpoint - pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), llmHandler) // Support audio/speech endpoint - pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) - pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST( + "/v1/audio/speech", + pm.apiKeyAuth(), + pm.trackInflight(), + pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders), + ) + pm.ginEngine.POST( + "/v1/audio/voices", + pm.apiKeyAuth(), + pm.trackInflight(), + pm.mkProxyJSONHandler(captureReqHeaders|captureRespAll), + ) pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler) - pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler) - pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) - pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler) + + pm.ginEngine.POST( + "/v1/audio/transcriptions", + pm.apiKeyAuth(), + pm.trackInflight(), + pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders|captureRespBody), + ) + pm.ginEngine.POST( + "/v1/images/generations", + pm.apiKeyAuth(), + pm.trackInflight(), + pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders), + ) + + pm.ginEngine.POST( + "/v1/images/edits", + pm.apiKeyAuth(), + pm.trackInflight(), + pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders), + ) // sd.cpp /sdapi/v1 endpoints - pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) - pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/sdapi/v1/txt2img", + pm.apiKeyAuth(), + pm.trackInflight(), + pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders), + ) + pm.ginEngine.POST("/sdapi/v1/img2img", + pm.apiKeyAuth(), + pm.trackInflight(), + pm.mkProxyJSONHandler(captureReqHeaders|captureRespHeaders), + ) pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler) pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler) @@ -686,7 +722,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { // attempt to record metrics if it is a POST request if pm.metricsMonitor != nil && c.Request.Method == "POST" { - if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil { + if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, captureNone, handler); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath) return @@ -700,280 +736,294 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { } } -func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) { - bodyBytes, err := io.ReadAll(c.Request.Body) - if err != nil { - pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body") - return - } - - requestedModel := gjson.GetBytes(bodyBytes, "model").String() - if requestedModel == "" { - pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") - return - } - - // Look for a matching local model first - var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error - - modelID, found := pm.config.RealModelName(requestedModel) - if found { - var localHandler func(string, http.ResponseWriter, *http.Request) error - if pm.matrix != nil { - localHandler = pm.matrix.ProxyRequest - } else { - processGroup, err := pm.swapProcessGroup(modelID) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) - return - } - localHandler = processGroup.ProxyRequest +func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context) { + return func(c *gin.Context) { + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body") + return } - // issue #69 allow custom model names to be sent to upstream - useModelName := pm.config.Models[modelID].UseModelName - if useModelName != "" { - bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) - return - } + requestedModel := gjson.GetBytes(bodyBytes, "model").String() + if requestedModel == "" { + pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") + return } - // issue #174 strip parameters from the JSON body - stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams() - if err != nil { // just log it and continue - pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error()) - } else { + // Look for a matching local model first + var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error + + modelID, found := pm.config.RealModelName(requestedModel) + if found { + var localHandler func(string, http.ResponseWriter, *http.Request) error + if pm.matrix != nil { + localHandler = pm.matrix.ProxyRequest + } else { + processGroup, err := pm.swapProcessGroup(modelID) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) + return + } + localHandler = processGroup.ProxyRequest + } + + // issue #69 allow custom model names to be sent to upstream + useModelName := pm.config.Models[modelID].UseModelName + if useModelName != "" { + bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) + return + } + } + + // issue #174 strip parameters from the JSON body + stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams() + if err != nil { // just log it and continue + pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error()) + } else { + for _, param := range stripParams { + pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param) + bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) + return + } + } + } + + // issue #453 set/override parameters in the JSON body + setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams() + for _, key := range setParamKeys { + pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key) + bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) + return + } + } + + // setParamsByID: set params based on the requested model ID (runs after setParams, can override it) + setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel) + for _, key := range setParamsByIDKeys { + pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key) + bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key]) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) + return + } + } + + pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) + nextHandler = localHandler + } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { + pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) + modelID = requestedModel + + // issue #453 apply filters for peer requests + peerFilters := pm.peerProxy.GetPeerFilters(requestedModel) + + // Apply stripParams - remove specified parameters from request + stripParams := peerFilters.SanitizedStripParams() for _, param := range stripParams { - pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param) + pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param) bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param)) + return + } + } + + // Apply setParams - set/override specified parameters in request + setParams, setParamKeys := peerFilters.SanitizedSetParams() + for _, key := range setParamKeys { + pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key) + bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) + return + } + } + + nextHandler = pm.peerProxy.ProxyRequest + } + + if nextHandler == nil { + pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel)) + return + } + + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + // dechunk it as we already have all the body bytes see issue #11 + c.Request.Header.Del("transfer-encoding") + c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes))) + c.Request.ContentLength = int64(len(bodyBytes)) + + // issue #366 extract values that downstream handlers may need + isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool() + ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming) + ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID) + c.Request = c.Request.WithContext(ctx) + + if pm.metricsMonitor != nil && c.Request.Method == "POST" { + if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, cf, nextHandler); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) + pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID) + return + } + } else { + if err := nextHandler(modelID, c.Writer, c.Request); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) + pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) + return + } + } + } +} + +// mkPostFormHandler creates a POST form handler for inference backends +// with a custom captureFields to filter out large binary requests or responses. +func (pm *ProxyManager) mkPostFormHandler(cf captureFields) func(*gin.Context) { + return func(c *gin.Context) { + // Parse multipart form + if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk + pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) + return + } + + // Get model parameter from the form + requestedModel := c.Request.FormValue("model") + if requestedModel == "" { + pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data") + return + } + + // Look for a matching local model first, then check peers + var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error + var useModelName string + + modelID, found := pm.config.RealModelName(requestedModel) + if found { + if pm.matrix != nil { + nextHandler = pm.matrix.ProxyRequest + } else { + processGroup, err := pm.swapProcessGroup(modelID) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) + return + } + nextHandler = processGroup.ProxyRequest + } + + useModelName = pm.config.Models[modelID].UseModelName + pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) + } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { + pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) + modelID = requestedModel + nextHandler = pm.peerProxy.ProxyRequest + } + + if nextHandler == nil { + pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel)) + return + } + + // We need to reconstruct the multipart form in any case since the body is consumed + // Create a new buffer for the reconstructed request + var requestBuffer bytes.Buffer + multipartWriter := multipart.NewWriter(&requestBuffer) + + // Copy all form values + for key, values := range c.Request.MultipartForm.Value { + for _, value := range values { + fieldValue := value + // If this is the model field and we have a profile, use just the model name + if key == "model" { + // # issue #69 allow custom model names to be sent to upstream + if useModelName != "" { + fieldValue = useModelName + } else { + fieldValue = requestedModel + } + } + field, err := multipartWriter.CreateFormField(key) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field") + return + } + if _, err = field.Write([]byte(fieldValue)); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field") return } } } - // issue #453 set/override parameters in the JSON body - setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams() - for _, key := range setParamKeys { - pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key) - bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) - return - } - } - - // setParamsByID: set params based on the requested model ID (runs after setParams, can override it) - setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel) - for _, key := range setParamsByIDKeys { - pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key) - bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key]) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) - return - } - } - - pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) - nextHandler = localHandler - } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { - pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) - modelID = requestedModel - - // issue #453 apply filters for peer requests - peerFilters := pm.peerProxy.GetPeerFilters(requestedModel) - - // Apply stripParams - remove specified parameters from request - stripParams := peerFilters.SanitizedStripParams() - for _, param := range stripParams { - pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param) - bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param)) - return - } - } - - // Apply setParams - set/override specified parameters in request - setParams, setParamKeys := peerFilters.SanitizedSetParams() - for _, key := range setParamKeys { - pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key) - bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) - return - } - } - - nextHandler = pm.peerProxy.ProxyRequest - } - - if nextHandler == nil { - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel)) - return - } - - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - - // dechunk it as we already have all the body bytes see issue #11 - c.Request.Header.Del("transfer-encoding") - c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes))) - c.Request.ContentLength = int64(len(bodyBytes)) - - // issue #366 extract values that downstream handlers may need - isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool() - ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming) - ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID) - c.Request = c.Request.WithContext(ctx) - - if pm.metricsMonitor != nil && c.Request.Method == "POST" { - if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID) - return - } - } else { - if err := nextHandler(modelID, c.Writer, c.Request); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) - return - } - } -} - -func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { - // Parse multipart form - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) - return - } - - // Get model parameter from the form - requestedModel := c.Request.FormValue("model") - if requestedModel == "" { - pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data") - return - } - - // Look for a matching local model first, then check peers - var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error - var useModelName string - - modelID, found := pm.config.RealModelName(requestedModel) - if found { - if pm.matrix != nil { - nextHandler = pm.matrix.ProxyRequest - } else { - processGroup, err := pm.swapProcessGroup(modelID) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) - return - } - nextHandler = processGroup.ProxyRequest - } - - useModelName = pm.config.Models[modelID].UseModelName - pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) - } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { - pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) - modelID = requestedModel - nextHandler = pm.peerProxy.ProxyRequest - } - - if nextHandler == nil { - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel)) - return - } - - // We need to reconstruct the multipart form in any case since the body is consumed - // Create a new buffer for the reconstructed request - var requestBuffer bytes.Buffer - multipartWriter := multipart.NewWriter(&requestBuffer) - - // Copy all form values - for key, values := range c.Request.MultipartForm.Value { - for _, value := range values { - fieldValue := value - // If this is the model field and we have a profile, use just the model name - if key == "model" { - // # issue #69 allow custom model names to be sent to upstream - if useModelName != "" { - fieldValue = useModelName - } else { - fieldValue = requestedModel + // Copy all files from the original request + for key, fileHeaders := range c.Request.MultipartForm.File { + for _, fileHeader := range fileHeaders { + formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file") + return } - } - field, err := multipartWriter.CreateFormField(key) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field") - return - } - if _, err = field.Write([]byte(fieldValue)); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field") - return - } - } - } - // Copy all files from the original request - for key, fileHeaders := range c.Request.MultipartForm.File { - for _, fileHeader := range fileHeaders { - formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file") - return - } + file, err := fileHeader.Open() + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file") + return + } - file, err := fileHeader.Open() - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file") - return - } - - if _, err = io.Copy(formFile, file); err != nil { + if _, err = io.Copy(formFile, file); err != nil { + file.Close() + pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data") + return + } file.Close() - pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data") + } + } + + // Close the multipart writer to finalize the form + if err := multipartWriter.Close(); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form") + return + } + + // Create a new request with the reconstructed form data + modifiedReq, err := http.NewRequestWithContext( + c.Request.Context(), + c.Request.Method, + c.Request.URL.String(), + &requestBuffer, + ) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request") + return + } + + // Copy the headers from the original request + modifiedReq.Header = c.Request.Header.Clone() + modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) + + // set the content length of the body + modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len())) + modifiedReq.ContentLength = int64(requestBuffer.Len()) + + // Use the modified request for proxying + if pm.metricsMonitor != nil { + if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, modifiedReq, cf, nextHandler); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) + pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) + return + } + } else { + if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) + pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) return } - file.Close() } } - - // Close the multipart writer to finalize the form - if err := multipartWriter.Close(); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form") - return - } - - // Create a new request with the reconstructed form data - modifiedReq, err := http.NewRequestWithContext( - c.Request.Context(), - c.Request.Method, - c.Request.URL.String(), - &requestBuffer, - ) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request") - return - } - - // Copy the headers from the original request - modifiedReq.Header = c.Request.Header.Clone() - modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) - - // set the content length of the body - modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len())) - modifiedReq.ContentLength = int64(requestBuffer.Len()) - - // Use the modified request for proxying - if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) - return - } } func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) { diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index 1f46d3ac..7942b8aa 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -158,7 +158,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) { } } - sendMetrics := func(metrics []TokenMetrics) { + sendMetrics := func(metrics []ActivityLogEntry) { jsonData, err := json.Marshal(metrics) if err == nil { select { @@ -205,8 +205,8 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) { /** * Send Metrics data */ - defer event.On(func(e TokenMetricsEvent) { - sendMetrics([]TokenMetrics{e.Metrics}) + defer event.On(func(e ActivityLogEvent) { + sendMetrics([]ActivityLogEntry{e.Metrics}) })() /** @@ -290,26 +290,16 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) { return } - data, exists := pm.metricsMonitor.getCompressedBytes(id) - if !exists { + capture := pm.metricsMonitor.getCaptureByID(id) + if capture == nil { c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"}) return } - c.Header("Vary", "Accept-Encoding") - - // ¯\_(ツ)_/¯ quality weights are too fancy for us anyway - hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd") - - if hasZstd { - c.Header("Content-Encoding", "zstd") - c.Data(http.StatusOK, "application/json", data) - } else { - decompressed, err := decompressCapture(data) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"}) - return - } - c.Data(http.StatusOK, "application/json", decompressed) + jsonBytes, err := json.Marshal(capture) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"}) + return } + c.Data(http.StatusOK, "application/json", jsonBytes) } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index a5dbfbc4..041c013b 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -1721,3 +1721,61 @@ models: assert.Contains(t, w.Body.String(), "could not find suitable handler") }) } + +func TestProxyManager_AudioTranscriptionCapture(t *testing.T) { + cfg := testConfigFromYAML(t, ` +healthCheckTimeout: 15 +logLevel: error +captureBuffer: 5 +models: + TheExpectedModel: + cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel +`) + + proxy := New(cfg) + defer proxy.StopProcesses(StopWaitForInflightRequest) + injectTestHandlers(proxy, nil) + + var b bytes.Buffer + w := multipart.NewWriter(&b) + + fw, err := w.CreateFormField("model") + assert.NoError(t, err) + _, err = fw.Write([]byte("TheExpectedModel")) + assert.NoError(t, err) + + fw, err = w.CreateFormFile("file", "test.mp3") + assert.NoError(t, err) + _, err = fw.Write([]byte("test audio content")) + assert.NoError(t, err) + w.Close() + + req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) + req.Header.Set("Content-Type", w.FormDataContentType()) + req.Header.Set("Authorization", "Bearer mysecret") + req.Header.Set("X-Custom-Req", "req-value") + rec := CreateTestResponseRecorder() + proxy.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + // Verify capture exists + metrics := proxy.metricsMonitor.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.True(t, metrics[0].HasCapture) + + capture := proxy.metricsMonitor.getCaptureByID(metrics[0].ID) + assert.NotNil(t, capture) + + // Should capture request headers (sensitive ones redacted) + assert.NotEmpty(t, capture.ReqHeaders) + assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"]) + assert.Equal(t, "req-value", capture.ReqHeaders["X-Custom-Req"]) + + // Should capture response headers + assert.NotNil(t, capture.RespHeaders) + + // Should NOT capture request bodies but get response bodies (text + assert.Nil(t, capture.ReqBody) + assert.NotNil(t, capture.RespBody) +} diff --git a/ui-svelte/package-lock.json b/ui-svelte/package-lock.json index d41cd547..11674498 100644 --- a/ui-svelte/package-lock.json +++ b/ui-svelte/package-lock.json @@ -2788,9 +2788,9 @@ } }, "node_modules/postcss": { - "version": "8.5.8", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz", - "integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==", + "version": "8.5.12", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz", + "integrity": "sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==", "dev": true, "funding": [ { diff --git a/ui-svelte/src/components/ActivityStats.svelte b/ui-svelte/src/components/ActivityStats.svelte index 316d2abe..870ecdc9 100644 --- a/ui-svelte/src/components/ActivityStats.svelte +++ b/ui-svelte/src/components/ActivityStats.svelte @@ -9,13 +9,13 @@ let stats = $derived.by(() => { const totalRequests = $metrics.length; - const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0); - const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0); - const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.cache_tokens, 0); + const totalInputTokens = $metrics.reduce((sum, m) => sum + m.tokens.input_tokens, 0); + const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.tokens.output_tokens, 0); + const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.tokens.cache_tokens, 0); - const promptPerSecond = $metrics.filter((m) => m.prompt_per_second > 0).map((m) => m.prompt_per_second); + const promptPerSecond = $metrics.filter((m) => m.tokens.prompt_per_second > 0).map((m) => m.tokens.prompt_per_second); - const tokensPerSecond = $metrics.filter((m) => m.tokens_per_second > 0).map((m) => m.tokens_per_second); + const tokensPerSecond = $metrics.filter((m) => m.tokens.tokens_per_second > 0).map((m) => m.tokens.tokens_per_second); const promptHistogramData = promptPerSecond.length > 0 ? calculateHistogramData(promptPerSecond) : null; diff --git a/ui-svelte/src/lib/types.ts b/ui-svelte/src/lib/types.ts index f15f20a9..b5a1a528 100644 --- a/ui-svelte/src/lib/types.ts +++ b/ui-svelte/src/lib/types.ts @@ -12,15 +12,22 @@ export interface Model { aliases?: string[]; } -export interface Metrics { - id: number; - timestamp: string; - model: string; +export interface TokenMetrics { cache_tokens: number; input_tokens: number; output_tokens: number; prompt_per_second: number; tokens_per_second: number; +} + +export interface ActivityLogEntry { + id: number; + timestamp: string; + model: string; + req_path: string; + resp_content_type: string; + resp_status_code: number; + tokens: TokenMetrics; duration_ms: number; has_capture: boolean; } diff --git a/ui-svelte/src/routes/Activity.svelte b/ui-svelte/src/routes/Activity.svelte index 3790cf70..3d59c513 100644 --- a/ui-svelte/src/routes/Activity.svelte +++ b/ui-svelte/src/routes/Activity.svelte @@ -3,8 +3,87 @@ import ActivityStats from "../components/ActivityStats.svelte"; import Tooltip from "../components/Tooltip.svelte"; import CaptureDialog from "../components/CaptureDialog.svelte"; + import { persistentStore } from "../stores/persistent"; + import { onMount } from "svelte"; import type { ReqRespCapture } from "../lib/types"; + type ColumnKey = + | "id" + | "time" + | "model" + | "req_path" + | "resp_status_code" + | "resp_content_type" + | "cached" + | "prompt" + | "generated" + | "prompt_speed" + | "gen_speed" + | "duration" + | "capture"; + + interface ColumnDef { + key: ColumnKey; + label: string; + defaultVisible: boolean; + } + + const columns: ColumnDef[] = [ + { key: "id", label: "ID", defaultVisible: true }, + { key: "time", label: "Time", defaultVisible: true }, + { key: "model", label: "Model", defaultVisible: true }, + { key: "req_path", label: "Path", defaultVisible: false }, + { key: "resp_status_code", label: "Status", defaultVisible: false }, + { key: "resp_content_type", label: "Content-Type", defaultVisible: false }, + { key: "cached", label: "Cached", defaultVisible: true }, + { key: "prompt", label: "Prompt", defaultVisible: true }, + { key: "generated", label: "Generated", defaultVisible: true }, + { key: "prompt_speed", label: "Prompt Speed", defaultVisible: true }, + { key: "gen_speed", label: "Gen Speed", defaultVisible: true }, + { key: "duration", label: "Duration", defaultVisible: true }, + { key: "capture", label: "Capture", defaultVisible: true }, + ]; + + const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key); + + const visibleColumns = persistentStore( + "activity-columns", + defaultVisibleKeys + ); + + let columnsMenuOpen = $state(false); + let dropdownContainer: HTMLDivElement | null = null; + + onMount(() => { + function handleKeydown(e: KeyboardEvent) { + if (e.key === "Escape" && columnsMenuOpen) { + columnsMenuOpen = false; + } + } + function handleClick(e: MouseEvent) { + if (columnsMenuOpen && dropdownContainer && !dropdownContainer.contains(e.target as Node)) { + columnsMenuOpen = false; + } + } + document.addEventListener("keydown", handleKeydown); + document.addEventListener("click", handleClick); + return () => { + document.removeEventListener("keydown", handleKeydown); + document.removeEventListener("click", handleClick); + }; + }); + + function toggleColumn(key: ColumnKey) { + const current = $visibleColumns; + if (current.includes(key)) { + if (current.length > 1) { + visibleColumns.set(current.filter((k) => k !== key)); + } + } else { + visibleColumns.set([...current, key]); + } + } + function formatSpeed(speed: number): string { return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s"; } @@ -67,58 +146,150 @@ -
+
+
+
+ + {#if columnsMenuOpen} +
+
+ Columns +
+ {#each columns as col (col.key)} + + {/each} +
+ {/if} +
+
+ - - - - - - - - - - + {#if $visibleColumns.includes("id")} + + {/if} + {#if $visibleColumns.includes("time")} + + {/if} + {#if $visibleColumns.includes("model")} + + {/if} + {#if $visibleColumns.includes("req_path")} + + {/if} + {#if $visibleColumns.includes("resp_status_code")} + + {/if} + {#if $visibleColumns.includes("resp_content_type")} + + {/if} + {#if $visibleColumns.includes("cached")} + + {/if} + {#if $visibleColumns.includes("prompt")} + + {/if} + {#if $visibleColumns.includes("generated")} + + {/if} + {#if $visibleColumns.includes("prompt_speed")} + + {/if} + {#if $visibleColumns.includes("gen_speed")} + + {/if} + {#if $visibleColumns.includes("duration")} + + {/if} + {#if $visibleColumns.includes("capture")} + + {/if} {#if sortedMetrics.length === 0} - {:else} {#each sortedMetrics as metric (metric.id)} - - - - - - - - - - + {#if $visibleColumns.includes("id")} + + {/if} + {#if $visibleColumns.includes("time")} + + {/if} + {#if $visibleColumns.includes("model")} + + {/if} + {#if $visibleColumns.includes("req_path")} + + {/if} + {#if $visibleColumns.includes("resp_status_code")} + + {/if} + {#if $visibleColumns.includes("resp_content_type")} + + {/if} + {#if $visibleColumns.includes("cached")} + + {/if} + {#if $visibleColumns.includes("prompt")} + + {/if} + {#if $visibleColumns.includes("generated")} + + {/if} + {#if $visibleColumns.includes("prompt_speed")} + + {/if} + {#if $visibleColumns.includes("gen_speed")} + + {/if} + {#if $visibleColumns.includes("duration")} + + {/if} + {#if $visibleColumns.includes("capture")} + + {/if} {/each} {/if} diff --git a/ui-svelte/src/routes/LogViewer.svelte b/ui-svelte/src/routes/LogViewer.svelte index 2643c202..002de24f 100644 --- a/ui-svelte/src/routes/LogViewer.svelte +++ b/ui-svelte/src/routes/LogViewer.svelte @@ -10,7 +10,7 @@ const viewModeStore = persistentStore("logviewer-view-mode", "panels"); let direction = $derived<"horizontal" | "vertical">( - $screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal" + $screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal", ); @@ -30,7 +30,7 @@ class:bg-primary={$viewModeStore === "proxy"} class:text-btn-primary-text={$viewModeStore === "proxy"} > - Panel + Proxy
IDTimeModel - Cached - - Prompt - GeneratedPrompt ProcessingGeneration SpeedDurationCaptureIDTimeModelPathStatusContent-Type + Cached + + Prompt + GeneratedPrompt SpeedGen SpeedDurationCapture
+ No activity recorded
{metric.id + 1}{formatRelativeTime(metric.timestamp)}{metric.model}{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}{metric.input_tokens.toLocaleString()}{metric.output_tokens.toLocaleString()}{formatSpeed(metric.prompt_per_second)}{formatSpeed(metric.tokens_per_second)}{formatDuration(metric.duration_ms)} - {#if metric.has_capture} - - {:else} - - - {/if} - {metric.id + 1}{formatRelativeTime(metric.timestamp)}{metric.model}{metric.req_path || "-"}{metric.resp_status_code || "-"}{metric.resp_content_type || "-"}{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}{metric.tokens.input_tokens.toLocaleString()}{metric.tokens.output_tokens.toLocaleString()}{formatSpeed(metric.tokens.prompt_per_second)}{formatSpeed(metric.tokens.tokens_per_second)}{formatDuration(metric.duration_ms)} + {#if metric.has_capture} + + {:else} + - + {/if} +