proxy: compress captures with zstd (#668)

The previous captures were saved uncompressed in memory. In agentic
workflows there can be many turns with each request containing the
previous context in the body with a lot of redundant data. Use zstd to
compress the request and response data before keeping a copy of memory.

Results: 

- Average Percentage Saved: 73.19%
- Average Compression Factor: ~6.77:1
This commit is contained in:
Benson Wong
2026-04-17 23:29:37 -07:00
committed by GitHub
parent c3f0d43e6e
commit 5e3c646829
5 changed files with 205 additions and 79 deletions
+1
View File
@@ -6,6 +6,7 @@ require (
github.com/billziss-gh/golib v0.2.0
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0
github.com/klauspost/compress v1.18.5
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
+2
View File
@@ -34,6 +34,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
+99 -32
View File
@@ -13,10 +13,54 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/mostlygeek/llama-swap/event"
"github.com/tidwall/gjson"
)
// zstdEncOptions are the shared zstd encoder options for maximum compression.
var zstdEncOptions = []zstd.EOption{
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
}
// zstdDecOptions are the shared zstd decoder options.
var zstdDecOptions = []zstd.DOption{}
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
var zstdEncPool = &sync.Pool{
New: func() interface{} {
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
return enc
},
}
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
var zstdDecPool = &sync.Pool{
New: func() interface{} {
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
return dec
},
}
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
// Returns compressed bytes and the original JSON byte count for logging.
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
jsonBytes, err := json.Marshal(c)
if err != nil {
return nil, 0, fmt.Errorf("marshal capture: %w", err)
}
enc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(enc)
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
}
// decompressCapture decompresses zstd-compressed JSON and returns it.
func decompressCapture(data []byte) ([]byte, error) {
dec := zstdDecPool.Get().(*zstd.Decoder)
defer zstdDecPool.Put(dec)
return dec.DecodeAll(data, nil)
}
// TokenMetrics represents parsed token statistics from llama-server logs
type TokenMetrics struct {
ID int `json:"id"`
@@ -40,18 +84,6 @@ type ReqRespCapture struct {
RespBody []byte `json:"resp_body"`
}
// Size returns the approximate memory usage of this capture in bytes
func (c *ReqRespCapture) Size() int {
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
for k, v := range c.ReqHeaders {
size += len(k) + len(v)
}
for k, v := range c.RespHeaders {
size += len(k) + len(v)
}
return size
}
// TokenMetricsEvent represents a token metrics event
type TokenMetricsEvent struct {
Metrics TokenMetrics
@@ -71,10 +103,10 @@ type metricsMonitor struct {
// capture fields
enableCaptures bool
captures map[int]ReqRespCapture // map for O(1) lookup by ID
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
captureOrder []int // track insertion order for FIFO eviction
captureSize int // current total size in bytes
maxCaptureSize int // max bytes for captures
captureSize int // current total compressed size in bytes
maxCaptureSize int // max bytes for captures (uncompressed)
}
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
@@ -84,7 +116,7 @@ func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int)
logger: logger,
maxMetrics: maxMetrics,
enableCaptures: captureBufferMB > 0,
captures: make(map[int]ReqRespCapture),
captures: make(map[int][]byte),
captureOrder: make([]int, 0),
captureSize: 0,
maxCaptureSize: captureBufferMB * 1024 * 1024,
@@ -108,47 +140,82 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
}
// addCapture adds a new capture to the buffer with size-based eviction.
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize.
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
if !mp.enableCaptures {
return
}
mp.mu.Lock()
defer mp.mu.Unlock()
captureSize := capture.Size()
if captureSize > mp.maxCaptureSize {
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
compressed, uncompressedBytes, err := compressCapture(&capture)
if err != nil {
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
return
}
// Evict oldest (FIFO) until room available
captureSize := len(compressed)
if captureSize > mp.maxCaptureSize {
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
return
}
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
mp.mu.Lock()
defer mp.mu.Unlock()
// Evict oldest (FIFO) until room available for the compressed data
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
oldestID := mp.captureOrder[0]
mp.captureOrder = mp.captureOrder[1:]
if evicted, exists := mp.captures[oldestID]; exists {
mp.captureSize -= evicted.Size()
l := len(evicted)
mp.captureSize -= l
delete(mp.captures, oldestID)
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
}
}
mp.captures[capture.ID] = capture
mp.captures[capture.ID] = compressed
mp.captureOrder = append(mp.captureOrder, capture.ID)
mp.captureSize += captureSize
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
}
// getCaptureByID returns a capture by its ID, or nil if not found.
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
mp.mu.RLock()
defer mp.mu.RUnlock()
if capture, exists := mp.captures[id]; exists {
return &capture
data, exists := mp.captures[id]
return data, exists
}
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
// If decompress=false, returns the raw zstd-compressed bytes.
// Returns nil if the capture is not found.
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
mp.mu.RLock()
defer mp.mu.RUnlock()
data, exists := mp.captures[id]
if !exists {
return nil
}
if !decompress {
return data
}
decompressed, err := decompressCapture(data)
if err != nil {
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
return nil
}
return decompressed
}
// getMetrics returns a copy of the current metrics
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
mp.mu.RLock()
@@ -290,8 +357,8 @@ func (mp *metricsMonitor) wrapHandler(
RespHeaders: respHeaders,
RespBody: body,
}
// Only set HasCapture if the capture will actually be stored (not too large)
if capture.Size() <= mp.maxCaptureSize {
compressed, _, err := compressCapture(capture)
if err == nil && len(compressed) <= mp.maxCaptureSize {
tm.HasCapture = true
}
}
+81 -40
View File
@@ -5,6 +5,7 @@ import (
"compress/flate"
"compress/gzip"
"encoding/json"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
@@ -953,28 +954,27 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
})
}
func TestReqRespCapture_Size(t *testing.T) {
t.Run("calculates size correctly", func(t *testing.T) {
func TestReqRespCapture_CompressedSize(t *testing.T) {
t.Run("compressed size is smaller than uncompressed", func(t *testing.T) {
capture := ReqRespCapture{
ID: 1,
ReqPath: "/v1/chat/completions", // 20 bytes
ReqHeaders: map[string]string{
"Content-Type": "application/json", // 12 + 16 = 28
},
ReqBody: []byte("request body"), // 12 bytes
RespHeaders: map[string]string{
"X-Test": "value", // 6 + 5 = 11
},
RespBody: []byte("response body"), // 13 bytes
ReqPath: "/v1/chat/completions",
ReqBody: []byte(`{"model":"test","prompt":"hello world this is a test request body that is reasonably long"}`),
RespBody: []byte(`{"id":"resp-123","object":"chat.completion","created":1234567890,"model":"test-model","choices":[{"index":0,"message":{"role":"assistant","content":"This is a test response body with some meaningful content to compress"}},{"index":1,"message":{"role":"user","content":"Another message here"}}]}`),
}
// Expected: 20 + 12 + 13 + 28 + 11 = 84
assert.Equal(t, 84, capture.Size())
compressed, uncompressed, err := compressCapture(&capture)
assert.NoError(t, err)
assert.Greater(t, uncompressed, 0)
assert.True(t, len(compressed) < uncompressed, "compressed (%d bytes) should be smaller than uncompressed JSON (%d bytes)", len(compressed), uncompressed)
})
t.Run("handles empty capture", func(t *testing.T) {
t.Run("empty capture produces compressed output", func(t *testing.T) {
capture := ReqRespCapture{}
assert.Equal(t, 0, capture.Size())
compressed, _, err := compressCapture(&capture)
assert.NoError(t, err)
assert.NotNil(t, compressed)
assert.True(t, len(compressed) > 0)
})
}
@@ -989,7 +989,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
mm.addCapture(capture)
// Should not store capture
assert.Nil(t, mm.getCaptureByID(0))
assert.Nil(t, mm.getCaptureByID(0, false))
})
t.Run("adds capture when enabled", func(t *testing.T) {
@@ -1002,41 +1002,55 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
}
mm.addCapture(capture)
retrieved := mm.getCaptureByID(0)
retrieved := mm.getCaptureByID(0, true)
assert.NotNil(t, retrieved)
assert.Equal(t, 0, retrieved.ID)
assert.Equal(t, []byte("test request"), retrieved.ReqBody)
assert.Equal(t, []byte("test response"), retrieved.RespBody)
var decoded ReqRespCapture
err := json.Unmarshal(retrieved, &decoded)
assert.NoError(t, err)
assert.Equal(t, 0, decoded.ID)
assert.Equal(t, []byte("test request"), decoded.ReqBody)
assert.Equal(t, []byte("test response"), decoded.RespBody)
})
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100 // Set small limit for test
// Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
mm.maxCaptureSize = 450
// Add captures that will exceed the limit
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)}
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)}
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)}
// Use random-looking data that doesn't compress well with zstd
rng := rand.New(rand.NewSource(42))
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 80)}
rng.Read(capture1.ReqBody)
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 80)}
rng.Read(capture2.ReqBody)
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 80)}
rng.Read(capture3.ReqBody)
mm.addCapture(capture1)
mm.addCapture(capture2)
// Adding capture3 should evict capture1
mm.addCapture(capture3)
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted")
retrieved := mm.getCaptureByID(1, true)
assert.NotNil(t, retrieved, "capture 1 should exist")
retrieved = mm.getCaptureByID(2, true)
assert.NotNil(t, retrieved, "capture 2 should exist")
})
t.Run("skips capture larger than max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100
// Add a capture larger than max
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)}
// Use random data that doesn't compress well to create an oversized capture
rng := rand.New(rand.NewSource(99))
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 300)}
rng.Read(largeCapture.ReqBody)
mm.addCapture(largeCapture)
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored")
})
}
@@ -1044,21 +1058,44 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
t.Run("returns nil for non-existent ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
assert.Nil(t, mm.getCaptureByID(999))
assert.Nil(t, mm.getCaptureByID(999, false))
})
t.Run("returns capture by ID", func(t *testing.T) {
t.Run("returns decompressed capture by ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{
ID: 42,
ReqBody: []byte("test"),
ReqBody: []byte("test request"),
RespBody: []byte("test response"),
}
mm.addCapture(capture)
retrieved := mm.getCaptureByID(42)
retrieved := mm.getCaptureByID(42, true)
assert.NotNil(t, retrieved)
assert.Equal(t, 42, retrieved.ID)
var decoded ReqRespCapture
err := json.Unmarshal(retrieved, &decoded)
assert.NoError(t, err)
assert.Equal(t, 42, decoded.ID)
assert.Equal(t, []byte("test request"), decoded.ReqBody)
assert.Equal(t, []byte("test response"), decoded.RespBody)
})
t.Run("returns compressed bytes when decompress=false", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{
ID: 42,
ReqBody: []byte("test request body"),
RespBody: []byte("test response body"),
}
mm.addCapture(capture)
compressed := mm.getCaptureByID(42, false)
assert.NotNil(t, compressed)
// Compressed data should not be valid JSON (it's zstd-compressed)
assert.False(t, gjson.ValidBytes(compressed))
})
}
@@ -1135,9 +1172,13 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
assert.Equal(t, 1, len(metrics))
metricID := metrics[0].ID
// Check capture was stored with same ID
capture := mm.getCaptureByID(metricID)
assert.NotNil(t, capture)
// Check capture was stored with same ID (decompressed)
captureData := mm.getCaptureByID(metricID, true)
assert.NotNil(t, captureData)
var capture ReqRespCapture
err = json.Unmarshal(captureData, &capture)
assert.NoError(t, err)
assert.Equal(t, metricID, capture.ID)
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Equal(t, []byte(responseBody), capture.RespBody)
@@ -1173,7 +1214,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
assert.Equal(t, 1, len(metrics))
// But no capture
capture := mm.getCaptureByID(metrics[0].ID)
capture := mm.getCaptureByID(metrics[0].ID, false)
assert.Nil(t, capture)
})
}
+18 -3
View File
@@ -290,11 +290,26 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
return
}
capture := pm.metricsMonitor.getCaptureByID(id)
if capture == nil {
data, exists := pm.metricsMonitor.getCompressedBytes(id)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
return
}
c.JSON(http.StatusOK, capture)
c.Header("Vary", "Accept-Encoding")
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd")
if hasZstd {
c.Header("Content-Encoding", "zstd")
c.Data(http.StatusOK, "application/json", data)
} else {
decompressed, err := decompressCapture(data)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
return
}
c.Data(http.StatusOK, "application/json", decompressed)
}
}