diff --git a/go.mod b/go.mod index bd24ffc3..8b679d82 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/billziss-gh/golib v0.2.0 github.com/fsnotify/fsnotify v1.9.0 github.com/gin-gonic/gin v1.10.0 + github.com/klauspost/compress v1.18.5 github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 4f0b2d2b..348a72cf 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= diff --git a/proxy/metrics_monitor.go b/proxy/metrics_monitor.go index 661bb9eb..283c1662 100644 --- a/proxy/metrics_monitor.go +++ b/proxy/metrics_monitor.go @@ -13,10 +13,54 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/klauspost/compress/zstd" "github.com/mostlygeek/llama-swap/event" "github.com/tidwall/gjson" ) +// zstdEncOptions are the shared zstd encoder options; SpeedBetterCompression trades encode speed for a better ratio. +var zstdEncOptions = []zstd.EOption{ + zstd.WithEncoderLevel(zstd.SpeedBetterCompression), +} + +// zstdDecOptions are the shared zstd decoder options. +var zstdDecOptions = []zstd.DOption{} + +// zstdEncPool pools zstd.Encoder instances to reduce allocations. 
+var zstdEncPool = &sync.Pool{ + New: func() interface{} { + enc, _ := zstd.NewWriter(nil, zstdEncOptions...) + return enc + }, +} + +// zstdDecPool pools zstd.Decoder instances to reduce allocations. +var zstdDecPool = &sync.Pool{ + New: func() interface{} { + dec, _ := zstd.NewReader(nil, zstdDecOptions...) + return dec + }, +} + +// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd. +// Returns compressed bytes and the original JSON byte count for logging. +func compressCapture(c *ReqRespCapture) ([]byte, int, error) { + jsonBytes, err := json.Marshal(c) + if err != nil { + return nil, 0, fmt.Errorf("marshal capture: %w", err) + } + enc := zstdEncPool.Get().(*zstd.Encoder) + defer zstdEncPool.Put(enc) + return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil +} + +// decompressCapture decompresses zstd-compressed JSON and returns it. +func decompressCapture(data []byte) ([]byte, error) { + dec := zstdDecPool.Get().(*zstd.Decoder) + defer zstdDecPool.Put(dec) + return dec.DecodeAll(data, nil) +} + // TokenMetrics represents parsed token statistics from llama-server logs type TokenMetrics struct { ID int `json:"id"` @@ -40,18 +84,6 @@ type ReqRespCapture struct { RespBody []byte `json:"resp_body"` } -// Size returns the approximate memory usage of this capture in bytes -func (c *ReqRespCapture) Size() int { - size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody) - for k, v := range c.ReqHeaders { - size += len(k) + len(v) - } - for k, v := range c.RespHeaders { - size += len(k) + len(v) - } - return size -} - // TokenMetricsEvent represents a token metrics event type TokenMetricsEvent struct { Metrics TokenMetrics @@ -71,10 +103,10 @@ type metricsMonitor struct { // capture fields enableCaptures bool - captures map[int]ReqRespCapture // map for O(1) lookup by ID - captureOrder []int // track insertion order for FIFO eviction - captureSize int // current total size in bytes - maxCaptureSize int // max bytes for captures + captures 
map[int][]byte // zstd-compressed JSON of ReqRespCapture + captureOrder []int // track insertion order for FIFO eviction + captureSize int // current total compressed size in bytes + maxCaptureSize int // max total bytes of compressed captures kept in memory } // newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the @@ -84,7 +116,7 @@ func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) logger: logger, maxMetrics: maxMetrics, enableCaptures: captureBufferMB > 0, - captures: make(map[int]ReqRespCapture), + captures: make(map[int][]byte), captureOrder: make([]int, 0), captureSize: 0, maxCaptureSize: captureBufferMB * 1024 * 1024, @@ -108,45 +140,80 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int { } // addCapture adds a new capture to the buffer with size-based eviction. -// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize. +// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize. 
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) { if !mp.enableCaptures { return } - mp.mu.Lock() - defer mp.mu.Unlock() - - captureSize := capture.Size() - if captureSize > mp.maxCaptureSize { - mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize) + compressed, uncompressedBytes, err := compressCapture(&capture) + if err != nil { + mp.logger.Warnf("failed to compress capture: %v, skipping", err) return } - // Evict oldest (FIFO) until room available + captureSize := len(compressed) + if captureSize > mp.maxCaptureSize { + mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize) + return + } + + compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100 + + mp.mu.Lock() + defer mp.mu.Unlock() + + // Evict oldest (FIFO) until room available for the compressed data for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 { oldestID := mp.captureOrder[0] mp.captureOrder = mp.captureOrder[1:] if evicted, exists := mp.captures[oldestID]; exists { - mp.captureSize -= evicted.Size() + l := len(evicted) + mp.captureSize -= l delete(mp.captures, oldestID) + mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l) } } - mp.captures[capture.ID] = capture + mp.captures[capture.ID] = compressed mp.captureOrder = append(mp.captureOrder, capture.ID) mp.captureSize += captureSize + + mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio) } -// getCaptureByID returns a capture by its ID, or nil if not found. -func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture { +// getCompressedBytes returns the raw compressed bytes for a capture by ID. 
+func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) { mp.mu.RLock() defer mp.mu.RUnlock() - if capture, exists := mp.captures[id]; exists { - return &capture + data, exists := mp.captures[id] + return data, exists +} + +// getCaptureByID returns decompressed capture bytes if found and decompress=true. +// If decompress=false, returns the raw zstd-compressed bytes. +// Returns nil if the capture is not found or cannot be decompressed. +func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte { + mp.mu.RLock() + defer mp.mu.RUnlock() + + data, exists := mp.captures[id] + if !exists { + return nil } - return nil + + if !decompress { + return data + } + + decompressed, err := decompressCapture(data) + if err != nil { + mp.logger.Warnf("failed to decompress capture %d: %v", id, err) + return nil + } + + return decompressed } // getMetrics returns a copy of the current metrics @@ -290,8 +357,8 @@ func (mp *metricsMonitor) wrapHandler( RespHeaders: respHeaders, RespBody: body, } - // Only set HasCapture if the capture will actually be stored (not too large) - if capture.Size() <= mp.maxCaptureSize { + compressed, _, err := compressCapture(capture) + if err == nil && len(compressed) <= mp.maxCaptureSize { tm.HasCapture = true } } diff --git a/proxy/metrics_monitor_test.go b/proxy/metrics_monitor_test.go index b8c66a17..48372d9e 100644 --- a/proxy/metrics_monitor_test.go +++ b/proxy/metrics_monitor_test.go @@ -5,6 +5,7 @@ import ( "compress/flate" "compress/gzip" "encoding/json" + "math/rand" "net/http" "net/http/httptest" "sync" @@ -953,28 +954,27 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) { }) } -func TestReqRespCapture_Size(t *testing.T) { - t.Run("calculates size correctly", func(t *testing.T) { +func TestReqRespCapture_CompressedSize(t *testing.T) { + t.Run("compressed size is smaller than uncompressed", func(t *testing.T) { capture := ReqRespCapture{ - ID: 1, - ReqPath: "/v1/chat/completions", // 20 bytes - ReqHeaders: map[string]string{ - 
"Content-Type": "application/json", // 12 + 16 = 28 - }, - ReqBody: []byte("request body"), // 12 bytes - RespHeaders: map[string]string{ - "X-Test": "value", // 6 + 5 = 11 - }, - RespBody: []byte("response body"), // 13 bytes + ID: 1, + ReqPath: "/v1/chat/completions", + ReqBody: []byte(`{"model":"test","prompt":"hello world this is a test request body that is reasonably long"}`), + RespBody: []byte(`{"id":"resp-123","object":"chat.completion","created":1234567890,"model":"test-model","choices":[{"index":0,"message":{"role":"assistant","content":"This is a test response body with some meaningful content to compress"}},{"index":1,"message":{"role":"user","content":"Another message here"}}]}`), } - // Expected: 20 + 12 + 13 + 28 + 11 = 84 - assert.Equal(t, 84, capture.Size()) + compressed, uncompressed, err := compressCapture(&capture) + assert.NoError(t, err) + assert.Greater(t, uncompressed, 0) + assert.True(t, len(compressed) < uncompressed, "compressed (%d bytes) should be smaller than uncompressed JSON (%d bytes)", len(compressed), uncompressed) }) - t.Run("handles empty capture", func(t *testing.T) { + t.Run("empty capture produces compressed output", func(t *testing.T) { capture := ReqRespCapture{} - assert.Equal(t, 0, capture.Size()) + compressed, _, err := compressCapture(&capture) + assert.NoError(t, err) + assert.NotNil(t, compressed) + assert.True(t, len(compressed) > 0) }) } @@ -989,7 +989,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) { mm.addCapture(capture) // Should not store capture - assert.Nil(t, mm.getCaptureByID(0)) + assert.Nil(t, mm.getCaptureByID(0, false)) }) t.Run("adds capture when enabled", func(t *testing.T) { @@ -1002,41 +1002,55 @@ func TestMetricsMonitor_AddCapture(t *testing.T) { } mm.addCapture(capture) - retrieved := mm.getCaptureByID(0) + retrieved := mm.getCaptureByID(0, true) assert.NotNil(t, retrieved) - assert.Equal(t, 0, retrieved.ID) - assert.Equal(t, []byte("test request"), retrieved.ReqBody) - assert.Equal(t, 
[]byte("test response"), retrieved.RespBody) + + var decoded ReqRespCapture + err := json.Unmarshal(retrieved, &decoded) + assert.NoError(t, err) + assert.Equal(t, 0, decoded.ID) + assert.Equal(t, []byte("test request"), decoded.ReqBody) + assert.Equal(t, []byte("test response"), decoded.RespBody) }) t.Run("evicts oldest when exceeding max size", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 5) - mm.maxCaptureSize = 100 // Set small limit for test + // Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes. + // 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit. + mm.maxCaptureSize = 450 - // Add captures that will exceed the limit - capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)} - capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)} - capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)} + // Use random-looking data that doesn't compress well with zstd + rng := rand.New(rand.NewSource(42)) + capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 80)} + rng.Read(capture1.ReqBody) + capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 80)} + rng.Read(capture2.ReqBody) + capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 80)} + rng.Read(capture3.ReqBody) mm.addCapture(capture1) mm.addCapture(capture2) // Adding capture3 should evict capture1 mm.addCapture(capture3) - assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted") - assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist") - assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist") + assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted") + retrieved := mm.getCaptureByID(1, true) + assert.NotNil(t, retrieved, "capture 1 should exist") + retrieved = mm.getCaptureByID(2, true) + assert.NotNil(t, retrieved, "capture 2 should exist") }) t.Run("skips capture larger than max size", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 5) mm.maxCaptureSize = 100 - // Add 
a capture larger than max - largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)} + // Use random data that doesn't compress well to create an oversized capture + rng := rand.New(rand.NewSource(99)) + largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 300)} + rng.Read(largeCapture.ReqBody) mm.addCapture(largeCapture) - assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored") + assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored") }) } @@ -1044,21 +1058,44 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) { t.Run("returns nil for non-existent ID", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 5) - assert.Nil(t, mm.getCaptureByID(999)) + assert.Nil(t, mm.getCaptureByID(999, false)) }) - t.Run("returns capture by ID", func(t *testing.T) { + t.Run("returns decompressed capture by ID", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10, 5) capture := ReqRespCapture{ - ID: 42, - ReqBody: []byte("test"), + ID: 42, + ReqBody: []byte("test request"), + RespBody: []byte("test response"), } mm.addCapture(capture) - retrieved := mm.getCaptureByID(42) + retrieved := mm.getCaptureByID(42, true) assert.NotNil(t, retrieved) - assert.Equal(t, 42, retrieved.ID) + + var decoded ReqRespCapture + err := json.Unmarshal(retrieved, &decoded) + assert.NoError(t, err) + assert.Equal(t, 42, decoded.ID) + assert.Equal(t, []byte("test request"), decoded.ReqBody) + assert.Equal(t, []byte("test response"), decoded.RespBody) + }) + + t.Run("returns compressed bytes when decompress=false", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10, 5) + + capture := ReqRespCapture{ + ID: 42, + ReqBody: []byte("test request body"), + RespBody: []byte("test response body"), + } + mm.addCapture(capture) + + compressed := mm.getCaptureByID(42, false) + assert.NotNil(t, compressed) + // Compressed data should not be valid JSON (it's zstd-compressed) + assert.False(t, gjson.ValidBytes(compressed)) 
}) } @@ -1135,9 +1172,13 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) { assert.Equal(t, 1, len(metrics)) metricID := metrics[0].ID - // Check capture was stored with same ID - capture := mm.getCaptureByID(metricID) - assert.NotNil(t, capture) + // Check capture was stored with same ID (decompressed) + captureData := mm.getCaptureByID(metricID, true) + assert.NotNil(t, captureData) + + var capture ReqRespCapture + err = json.Unmarshal(captureData, &capture) + assert.NoError(t, err) assert.Equal(t, metricID, capture.ID) assert.Equal(t, []byte(requestBody), capture.ReqBody) assert.Equal(t, []byte(responseBody), capture.RespBody) @@ -1173,7 +1214,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) { assert.Equal(t, 1, len(metrics)) // But no capture - capture := mm.getCaptureByID(metrics[0].ID) + capture := mm.getCaptureByID(metrics[0].ID, false) assert.Nil(t, capture) }) } diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index ba0506f7..1f46d3ac 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -290,11 +290,26 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) { return } - capture := pm.metricsMonitor.getCaptureByID(id) - if capture == nil { + data, exists := pm.metricsMonitor.getCompressedBytes(id) + if !exists { c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"}) return } - c.JSON(http.StatusOK, capture) + c.Header("Vary", "Accept-Encoding") + + // NOTE: plain substring match; q-values are not parsed, so "zstd;q=0" is still treated as accepted. + hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd") + + if hasZstd { + c.Header("Content-Encoding", "zstd") + c.Data(http.StatusOK, "application/json", data) + } else { + decompressed, err := decompressCapture(data) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"}) + return + } + c.Data(http.StatusOK, "application/json", decompressed) + } }