Compare commits

...

12 Commits

Author SHA1 Message Date
Marcus c79114d40a proxy: fix logger not checking matrix for processes
Fix matrix not being used to search for a logger causing /logs/stream/model_name to return an error
2026-05-01 16:43:20 -07:00
Benson Wong 430166d5eb proxy: fix zero duration for non streaming responses (#723)
Updates #654
2026-04-30 19:51:28 -07:00
Marcus 5b4beaceef fix: ?no-history flag and improve /logs monitoring docs (#721)
- improve logging documentation 
- small tweaks for edge case issues in upstream and log requests
2026-04-30 00:50:36 -07:00
Benson Wong fd3c28ffc5 Refactor Activity Page (#710)
- inference handles to store an activity record for all inference endpoints
- add path, status code, and content type to Activities page
- toggle on/off columns on Activities page
- add configurable capture level for inference endpoints so large binary blobs are not stored in memory
- store captures in compressed binary format
2026-04-28 20:33:03 -07:00
Quentin Machu a846c4f18c config: remove hard cap on macro length (#718)
Remove macro value limit of 1024 characters
2026-04-28 13:32:54 -07:00
Marcus 5bae33a769 ui-svelte: default theme to user preferred color scheme (#712)
Simple: if not set in localStorage, use whatever the user's preferred
color scheme is to start.
2026-04-27 06:44:22 -07:00
Benson Wong 8f4ff01f93 ui-svelte: make it easier to toggle panels in logs view 2026-04-26 22:12:43 -07:00
Benson Wong e8d4384cd2 ui-svelte: support reasoning and reasoning_content (#708)
Support `reasoning` v1/chat/completion delta that vLLM uses.
2026-04-26 13:11:48 -07:00
Benson Wong ce28485be2 ui-svelte: add prompt processing histogram (#705)
Activities page shows histograms for prompt processing and token generation times. 

Fix: #691
Fix: #703
2026-04-25 16:13:07 -07:00
Damir 3cd7837b1f fix: support architecture-specific download URLs in install script (#698)
Just a small fix to include the proper llama-swap binary when building for
the arm64 architecture.
2026-04-23 18:05:33 -07:00
Benson Wong 0b31ccacc1 ui-svelte: fix histogram calculation (#695)
- Fix the histogram calculation to use server provided generation
tokens/second.
- Move histogram to Activities page where it can exist with the rest of
the token metrics

Fixes #681
2026-04-22 23:42:39 -07:00
Bryan Gahagan 5938dbee8f Push unified docker images on scheduled runs (#694)
Fixes #693
2026-04-22 20:46:51 -07:00
35 changed files with 2039 additions and 917 deletions
+2 -11
View File
@@ -19,9 +19,6 @@ jobs:
run-tests: run-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults:
run:
working-directory: ui-svelte
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -32,11 +29,5 @@ jobs:
cache: 'npm' cache: 'npm'
cache-dependency-path: ui-svelte/package-lock.json cache-dependency-path: ui-svelte/package-lock.json
- name: Install dependencies - name: Run UI tests
run: npm ci run: make test-ui
- name: Type check
run: npm run check
- name: Run tests
run: npm test
+1 -1
View File
@@ -121,7 +121,7 @@ jobs:
docker/unified/build-image.sh --${{ matrix.backend }} docker/unified/build-image.sh --${{ matrix.backend }}
- name: Push to GitHub Container Registry - name: Push to GitHub Container Registry
if: ${{ !env.ACT && inputs.push_to_ghcr == true }} if: ${{ !env.ACT && (github.event_name == 'schedule' || inputs.push_to_ghcr == true) }}
run: | run: |
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}" BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
DATE_TAG=$(date -u +%Y-%m-%d) DATE_TAG=$(date -u +%Y-%m-%d)
+1
View File
@@ -24,6 +24,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`. - Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory - Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests. - Use `make test-all` before completing work. This includes long running concurrency tests.
- Use `make test-ui` after making changes to the UI in ui-svelte/
### Commit message example format: ### Commit message example format:
+4 -1
View File
@@ -97,6 +97,9 @@ wol-proxy: $(BUILD_DIR)
@echo "Building wol-proxy" @echo "Building wol-proxy"
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
test-ui:
cd ui-svelte && npm ci && npm run check && npm test
# Phony targets # Phony targets
.PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev wol-proxy .PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev test-ui wol-proxy
.PHONE: linux linux-arm64 linux-amd64 .PHONE: linux linux-arm64 linux-amd64
+11 -2
View File
@@ -20,6 +20,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- `v1/chat/completions` - `v1/chat/completions`
- `v1/responses` - `v1/responses`
- `v1/embeddings` - `v1/embeddings`
- `v1/models` - list available models
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867)) - `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- `v1/audio/voices` - `v1/audio/voices`
@@ -39,9 +40,17 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- ✅ llama-swap API - ✅ llama-swap API
- `/ui` - web UI - `/ui` - web UI
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31)) - `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/log` - remote log monitoring - `POST /api/models/unload` - manually unload all running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `POST /api/models/unload/:model_id` - unload a specific model
- `/logs` - remote log monitoring
- `GET /logs` returns buffered plain text logs.
- If `Accept: text/html` is sent, `/logs` redirects to `/ui/`.
- `GET /logs/stream` keeps the connection open for live log streaming.
- Stream endpoints send buffered history first by default; add `?no-history` to stream only new lines.
- `GET /logs/stream/proxy` streams proxy logs only.
- `GET /logs/stream/upstream` streams upstream process logs only.
- `GET /logs/stream/{model_id}` streams logs for one model (including IDs with slashes, like `author/model`).
- `/health` - just returns "OK" - `/health` - just returns "OK"
- ✅ API Key support - define keys to restrict access to API endpoints - ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable - ✅ Customizable
+10 -2
View File
@@ -38,8 +38,16 @@ if [ "$VERSION" = "latest" ]; then
echo "Latest version: ${VERSION}" echo "Latest version: ${VERSION}"
fi fi
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;;
*) echo "FATAL: Unsupported architecture: $ARCH" >&2; exit 1 ;;
esac
# Download and extract # Download and extract
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_amd64.tar.gz" URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_${ARCH}.tar.gz"
echo "=== Downloading llama-swap v${VERSION} ===" echo "=== Downloading llama-swap v${VERSION} ==="
echo "URL: $URL" echo "URL: $URL"
curl -fSL -o /tmp/llama-swap.tar.gz "$URL" curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
@@ -56,4 +64,4 @@ fi
echo "$VERSION" > /install/llama-swap-version echo "$VERSION" > /install/llama-swap-version
echo "=== llama-swap v${VERSION} installed ===" echo "=== llama-swap v${VERSION} installed ==="
ls -la /install/bin/llama-swap ls -la /install/bin/llama-swap
+2
View File
@@ -4,6 +4,7 @@ go 1.26.1
require ( require (
github.com/billziss-gh/golib v0.2.0 github.com/billziss-gh/golib v0.2.0
github.com/fxamacker/cbor/v2 v2.9.1
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/klauspost/compress v1.18.5 github.com/klauspost/compress v1.18.5
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
@@ -36,6 +37,7 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.45.0 // indirect golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect golang.org/x/net v0.47.0 // indirect
+4
View File
@@ -11,6 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -77,6 +79,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
+102
View File
@@ -0,0 +1,102 @@
package cache
import (
"errors"
"sync"
)
var (
ErrExceedsMaxSize = errors.New("item exceeds maximum cache size")
ErrNotFound = errors.New("item not found")
)
type Cache struct {
mu sync.Mutex
items map[int][]byte
order []int
size int
maxSize int
}
func New(maxBytes int) *Cache {
return &Cache{
items: make(map[int][]byte),
order: make([]int, 0),
maxSize: maxBytes,
}
}
func (c *Cache) Add(id int, data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
dataSize := len(data)
if dataSize > c.maxSize {
return ErrExceedsMaxSize
}
// If key already exists, remove old entry from size and order
if old, exists := c.items[id]; exists {
c.size -= len(old)
c.removeOrder(id)
}
// Evict oldest (FIFO) until room available
for c.size+dataSize > c.maxSize && len(c.order) > 0 {
oldestID := c.order[0]
c.order = c.order[1:]
if evicted, exists := c.items[oldestID]; exists {
c.size -= len(evicted)
delete(c.items, oldestID)
}
}
c.items[id] = data
c.order = append(c.order, id)
c.size += dataSize
return nil
}
func (c *Cache) removeOrder(id int) {
for i, v := range c.order {
if v == id {
c.order = append(c.order[:i], c.order[i+1:]...)
return
}
}
}
func (c *Cache) Get(id int) ([]byte, error) {
c.mu.Lock()
defer c.mu.Unlock()
data, exists := c.items[id]
if !exists {
return nil, ErrNotFound
}
return data, nil
}
func (c *Cache) Has(id int) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, exists := c.items[id]
return exists
}
func (c *Cache) Size() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.size
}
func (c *Cache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.items = make(map[int][]byte)
c.order = c.order[:0]
c.size = 0
}
+130
View File
@@ -0,0 +1,130 @@
package cache
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCache_Add covers the four Add paths: plain insert, rejection of an
// item larger than the budget, FIFO eviction, and replacing an existing key.
func TestCache_Add(t *testing.T) {
	t.Run("adds and retrieves item", func(t *testing.T) {
		store := New(1024)
		payload := []byte("hello")

		addErr := store.Add(1, payload)
		require.NoError(t, addErr)

		fetched, getErr := store.Get(1)
		require.NoError(t, getErr)
		assert.Equal(t, payload, fetched)
	})

	t.Run("returns error for oversized item", func(t *testing.T) {
		store := New(10)
		tooBig := make([]byte, 20)
		assert.ErrorIs(t, store.Add(1, tooBig), ErrExceedsMaxSize)
	})

	t.Run("evicts oldest items to make room", func(t *testing.T) {
		store := New(100)
		// Three 40-byte items into a 100-byte budget: the third insert
		// must push out the first.
		for _, id := range []int{1, 2, 3} {
			require.NoError(t, store.Add(id, make([]byte, 40)))
		}

		assert.False(t, store.Has(1))
		assert.True(t, store.Has(2))
		assert.True(t, store.Has(3))
	})

	t.Run("overwrites existing key", func(t *testing.T) {
		store := New(100)
		for _, value := range []string{"old", "new"} {
			require.NoError(t, store.Add(1, []byte(value)))
		}

		fetched, getErr := store.Get(1)
		require.NoError(t, getErr)
		assert.Equal(t, []byte("new"), fetched)
		// Only the replacement's 3 bytes should be counted.
		assert.Equal(t, 3, store.Size())
	})
}
// TestCache_Get checks that looking up an id that was never added yields
// ErrNotFound.
func TestCache_Get(t *testing.T) {
	t.Run("returns ErrNotFound for missing key", func(t *testing.T) {
		empty := New(100)

		_, lookupErr := empty.Get(99)

		assert.ErrorIs(t, lookupErr, ErrNotFound)
	})
}
// TestCache_Has checks the presence query for a stored id and for an id
// that was never stored.
func TestCache_Has(t *testing.T) {
	t.Run("returns true for existing key", func(t *testing.T) {
		store := New(100)
		addErr := store.Add(1, []byte("data"))
		require.NoError(t, addErr)

		present := store.Has(1)
		assert.True(t, present)
	})

	t.Run("returns false for missing key", func(t *testing.T) {
		store := New(100)

		present := store.Has(1)
		assert.False(t, present)
	})
}
// TestCache_Size checks that the byte counter follows inserts and is
// reduced when an entry is evicted.
func TestCache_Size(t *testing.T) {
	t.Run("tracks byte usage", func(t *testing.T) {
		store := New(1000)
		assert.Equal(t, 0, store.Size())

		first := make([]byte, 100)
		require.NoError(t, store.Add(1, first))
		assert.Equal(t, 100, store.Size())

		second := make([]byte, 200)
		require.NoError(t, store.Add(2, second))
		assert.Equal(t, 300, store.Size())
	})

	t.Run("updates on eviction", func(t *testing.T) {
		store := New(150)
		// Two 100-byte items cannot share a 150-byte budget, so the
		// second insert evicts the first and 100 bytes remain.
		for _, id := range []int{1, 2} {
			require.NoError(t, store.Add(id, make([]byte, 100)))
		}

		assert.Equal(t, 100, store.Size())
	})
}
// TestCache_Clear checks that Clear leaves the cache empty with a zero
// byte count.
func TestCache_Clear(t *testing.T) {
	t.Run("removes all items and resets size", func(t *testing.T) {
		store := New(1000)

		errFirst := store.Add(1, []byte("a"))
		require.NoError(t, errFirst)
		errSecond := store.Add(2, []byte("b"))
		require.NoError(t, errSecond)

		store.Clear()

		assert.Equal(t, 0, store.Size())
		assert.False(t, store.Has(1))
		assert.False(t, store.Has(2))
	})
}
// TestCache_Concurrent hammers one Cache from several goroutines at once.
// It makes no assertions of its own; it is meant to be run with the race
// detector, which reports any unsynchronized access.
func TestCache_Concurrent(t *testing.T) {
	t.Run("concurrent operations are safe", func(t *testing.T) {
		const (
			workers   = 10
			perWorker = 100
		)
		shared := New(10000)

		var pending sync.WaitGroup
		pending.Add(workers)
		for w := 0; w < workers; w++ {
			go func(worker int) {
				defer pending.Done()
				// Each worker uses its own disjoint key range.
				base := worker * 100
				for n := 0; n < perWorker; n++ {
					key := base + n
					_ = shared.Add(key, []byte("data"))
					_, _ = shared.Get(key)
					_ = shared.Has(key)
					_ = shared.Size()
				}
			}(w)
		}
		pending.Wait()
	})
}
-3
View File
@@ -646,9 +646,6 @@ func validateMacro(name string, value any) error {
// Validate that value is a scalar type // Validate that value is a scalar type
switch v := value.(type) { switch v := value.(type) {
case string: case string:
if len(v) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
// Check for self-reference // Check for self-reference
macroSlug := fmt.Sprintf("${%s}", name) macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(v, macroSlug) { if strings.Contains(v, macroSlug) {
+1 -1
View File
@@ -6,7 +6,7 @@ const ProcessStateChangeEventID = 0x01
const ChatCompletionStatsEventID = 0x02 const ChatCompletionStatsEventID = 0x02
const ConfigFileChangedEventID = 0x03 const ConfigFileChangedEventID = 0x03
const LogDataEventID = 0x04 const LogDataEventID = 0x04
const TokenMetricsEventID = 0x05 const ActivityLogEventID = 0x05
const ModelPreloadedEventID = 0x06 const ModelPreloadedEventID = 0x06
const InFlightRequestsEventID = 0x07 const InFlightRequestsEventID = 0x07
+182 -135
View File
@@ -12,9 +12,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/cache"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -42,37 +44,53 @@ var zstdDecPool = &sync.Pool{
}, },
} }
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd. // compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
// Returns compressed bytes and the original JSON byte count for logging. // Returns compressed bytes and the original CBOR byte count for logging.
func compressCapture(c *ReqRespCapture) ([]byte, int, error) { func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
jsonBytes, err := json.Marshal(c) cborBytes, err := cbor.Marshal(c)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("marshal capture: %w", err) return nil, 0, fmt.Errorf("marshal capture: %w", err)
} }
enc := zstdEncPool.Get().(*zstd.Encoder) zenc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(enc) defer zstdEncPool.Put(zenc)
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
} }
// decompressCapture decompresses zstd-compressed JSON and returns it. // decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture.
func decompressCapture(data []byte) ([]byte, error) { func decompressCapture(data []byte) (*ReqRespCapture, error) {
dec := zstdDecPool.Get().(*zstd.Decoder) dec := zstdDecPool.Get().(*zstd.Decoder)
defer zstdDecPool.Put(dec) defer zstdDecPool.Put(dec)
return dec.DecodeAll(data, nil) cborBytes, err := dec.DecodeAll(data, nil)
if err != nil {
return nil, fmt.Errorf("decompress capture: %w", err)
}
var capture ReqRespCapture
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
return nil, fmt.Errorf("unmarshal capture: %w", err)
}
return &capture, nil
} }
// TokenMetrics represents parsed token statistics from llama-server logs // TokenMetrics holds token usage and performance metrics
type TokenMetrics struct { type TokenMetrics struct {
ID int `json:"id"` CachedTokens int `json:"cache_tokens"`
Timestamp time.Time `json:"timestamp"` InputTokens int `json:"input_tokens"`
Model string `json:"model"` OutputTokens int `json:"output_tokens"`
CachedTokens int `json:"cache_tokens"` PromptPerSecond float64 `json:"prompt_per_second"`
InputTokens int `json:"input_tokens"` TokensPerSecond float64 `json:"tokens_per_second"`
OutputTokens int `json:"output_tokens"` }
PromptPerSecond float64 `json:"prompt_per_second"`
TokensPerSecond float64 `json:"tokens_per_second"` // ActivityLogEntry represents parsed token statistics from llama-server logs
DurationMs int `json:"duration_ms"` type ActivityLogEntry struct {
HasCapture bool `json:"has_capture"` ID int `json:"id"`
Timestamp time.Time `json:"timestamp"`
Model string `json:"model"`
ReqPath string `json:"req_path"`
RespContentType string `json:"resp_content_type"`
RespStatusCode int `json:"resp_status_code"`
Tokens TokenMetrics `json:"tokens"`
DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"`
} }
type ReqRespCapture struct { type ReqRespCapture struct {
@@ -84,48 +102,45 @@ type ReqRespCapture struct {
RespBody []byte `json:"resp_body"` RespBody []byte `json:"resp_body"`
} }
// TokenMetricsEvent represents a token metrics event // ActivityLogEvent represents a token metrics event
type TokenMetricsEvent struct { type ActivityLogEvent struct {
Metrics TokenMetrics Metrics ActivityLogEntry
} }
func (e TokenMetricsEvent) Type() uint32 { func (e ActivityLogEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go return ActivityLogEventID // defined in events.go
} }
// metricsMonitor parses llama-server output for token statistics // metricsMonitor parses llama-server output for token statistics
type metricsMonitor struct { type metricsMonitor struct {
mu sync.RWMutex mu sync.RWMutex
metrics []TokenMetrics metrics []ActivityLogEntry
maxMetrics int maxMetrics int
nextID int nextID int
logger *LogMonitor logger *LogMonitor
// capture fields // capture fields
enableCaptures bool enableCaptures bool
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
captureOrder []int // track insertion order for FIFO eviction
captureSize int // current total compressed size in bytes
maxCaptureSize int // max bytes for captures (uncompressed)
} }
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the // newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
// capture buffer size in megabytes; 0 disables captures. // capture buffer size in megabytes; 0 disables captures.
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor { func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
return &metricsMonitor{ mm := &metricsMonitor{
logger: logger, logger: logger,
maxMetrics: maxMetrics, maxMetrics: maxMetrics,
enableCaptures: captureBufferMB > 0, enableCaptures: captureBufferMB > 0,
captures: make(map[int][]byte),
captureOrder: make([]int, 0),
captureSize: 0,
maxCaptureSize: captureBufferMB * 1024 * 1024,
} }
if captureBufferMB > 0 {
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
}
return mm
} }
// addMetrics adds a new metric to the collection and publishes an event. // queueMetrics adds a new metric to the collection without emitting an event.
// Returns the assigned metric ID. // Returns the assigned metric ID. Call emitMetric after capture setup.
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int { func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
mp.mu.Lock() mp.mu.Lock()
defer mp.mu.Unlock() defer mp.mu.Unlock()
@@ -135,93 +150,75 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
if len(mp.metrics) > mp.maxMetrics { if len(mp.metrics) > mp.maxMetrics {
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:] mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
} }
event.Emit(TokenMetricsEvent{Metrics: metric})
return metric.ID return metric.ID
} }
// addCapture adds a new capture to the buffer with size-based eviction. // emitMetric publishes an ActivityLogEvent for the given metric.
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize. func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) { event.Emit(ActivityLogEvent{Metrics: metric})
}
// addCapture compresses and stores a capture in the cache.
// Returns true if the capture was stored, false otherwise.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
if !mp.enableCaptures { if !mp.enableCaptures {
return return false
} }
compressed, uncompressedBytes, err := compressCapture(&capture) compressed, uncompressedBytes, err := compressCapture(&capture)
if err != nil { if err != nil {
mp.logger.Warnf("failed to compress capture: %v, skipping", err) mp.logger.Warnf("failed to compress capture: %v, skipping", err)
return return false
} }
captureSize := len(compressed) if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
if captureSize > mp.maxCaptureSize { mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize) return false
return
} }
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100 compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
mp.mu.Lock()
defer mp.mu.Unlock()
// Evict oldest (FIFO) until room available for the compressed data
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
oldestID := mp.captureOrder[0]
mp.captureOrder = mp.captureOrder[1:]
if evicted, exists := mp.captures[oldestID]; exists {
l := len(evicted)
mp.captureSize -= l
delete(mp.captures, oldestID)
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
}
}
mp.captures[capture.ID] = compressed
mp.captureOrder = append(mp.captureOrder, capture.ID)
mp.captureSize += captureSize
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio) mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
return true
} }
// getCompressedBytes returns the raw compressed bytes for a capture by ID. // getCompressedBytes returns the raw compressed bytes for a capture by ID.
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) { func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
mp.mu.RLock() if mp.captureCache == nil {
defer mp.mu.RUnlock() return nil, false
}
data, exists := mp.captures[id] data, err := mp.captureCache.Get(id)
return data, exists if err != nil {
return nil, false
}
return data, true
} }
// getCaptureByID returns decompressed capture bytes if found and decompress=true. // getCaptureByID decompresses and unmarshals a capture by ID.
// If decompress=false, returns the raw zstd-compressed bytes. // Returns nil if the capture is not found or decompression fails.
// Returns nil if the capture is not found. func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte { if mp.captureCache == nil {
mp.mu.RLock() return nil
defer mp.mu.RUnlock() }
data, exists := mp.getCompressedBytes(id)
data, exists := mp.captures[id]
if !exists { if !exists {
return nil return nil
} }
if !decompress { capture, err := decompressCapture(data)
return data
}
decompressed, err := decompressCapture(data)
if err != nil { if err != nil {
mp.logger.Warnf("failed to decompress capture %d: %v", id, err) mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
return nil return nil
} }
return decompressed return capture
} }
// getMetrics returns a copy of the current metrics // getMetrics returns a copy of the current metrics
func (mp *metricsMonitor) getMetrics() []TokenMetrics { func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
mp.mu.RLock() mp.mu.RLock()
defer mp.mu.RUnlock() defer mp.mu.RUnlock()
result := make([]TokenMetrics, len(mp.metrics)) result := make([]ActivityLogEntry, len(mp.metrics))
copy(result, mp.metrics) copy(result, mp.metrics)
return result return result
} }
@@ -230,22 +227,52 @@ func (mp *metricsMonitor) getMetrics() []TokenMetrics {
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) { func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
mp.mu.RLock() mp.mu.RLock()
defer mp.mu.RUnlock() defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
if mp.captureCache == nil {
return json.Marshal(mp.metrics)
}
// Make a copy with up-to-date has_capture from cache
result := make([]ActivityLogEntry, len(mp.metrics))
for i, m := range mp.metrics {
m.HasCapture = mp.captureCache.Has(m.ID)
result[i] = m
}
return json.Marshal(result)
} }
// wrapHandler wraps the proxy handler to extract token metrics // Capture field flags for controlling what is saved in ReqRespCapture.
// captureFields is a bit set selecting which parts of a request/response
// pair are saved in a ReqRespCapture. Combine flags with | and test them
// with &.
type captureFields uint

const (
	// captureNone is the empty set. It is explicitly 0 (previously it was
	// 1 << iota == 1, i.e. a real bit) so that it equals the zero value of
	// captureFields and OR-ing it into a mask is a no-op.
	captureNone captureFields = 0

	// Individual flags. iota is already 1 on the first line below, so the
	// values remain 2, 4, 8 and 16 exactly as before.
	captureReqHeaders captureFields = 1 << iota
	captureReqBody
	captureRespHeaders
	captureRespBody
)

// Convenience unions of the individual flags.
const (
	captureReqAll  = captureReqHeaders | captureReqBody
	captureRespAll = captureRespHeaders | captureRespBody
	captureAll     = captureReqAll | captureRespAll
)
// wrapHandler wraps the proxy handler to extract token metrics.
// captureFields controls what is saved in the ReqRespCapture using bitwise flags.
// if wrapHandler returns an error it is safe to assume that no // if wrapHandler returns an error it is safe to assume that no
// data was sent to the client // data was sent to the client
func (mp *metricsMonitor) wrapHandler( func (mp *metricsMonitor) wrapHandler(
modelID string, modelID string,
writer gin.ResponseWriter, writer gin.ResponseWriter,
request *http.Request, request *http.Request,
captureFields captureFields,
next func(modelID string, w http.ResponseWriter, r *http.Request) error, next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error { ) error {
// Capture request body and headers if captures enabled // Capture request body and headers if captures enabled
var reqBody []byte var reqBody []byte
var reqHeaders map[string]string var reqHeaders map[string]string
if mp.enableCaptures { if mp.enableCaptures && (captureFields&captureReqBody) != 0 {
if request.Body != nil { if request.Body != nil {
var err error var err error
reqBody, err = io.ReadAll(request.Body) reqBody, err = io.ReadAll(request.Body)
@@ -255,6 +282,8 @@ func (mp *metricsMonitor) wrapHandler(
request.Body.Close() request.Body.Close()
request.Body = io.NopCloser(bytes.NewBuffer(reqBody)) request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
} }
}
if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 {
reqHeaders = make(map[string]string) reqHeaders = make(map[string]string)
for key, values := range request.Header { for key, values := range request.Header {
if len(values) > 0 { if len(values) > 0 {
@@ -278,22 +307,28 @@ func (mp *metricsMonitor) wrapHandler(
// after this point we have to assume that data was sent to the client // after this point we have to assume that data was sent to the client
// and we can only log errors but not send them to clients // and we can only log errors but not send them to clients
if recorder.Status() != http.StatusOK { // Initialize default metrics - recorded for every request
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path) tm := ActivityLogEntry{
return nil Timestamp: time.Now(),
Model: modelID,
ReqPath: request.URL.Path,
RespContentType: recorder.Header().Get("Content-Type"),
RespStatusCode: recorder.Status(),
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
} }
// Initialize default metrics - these will always be recorded if recorder.Status() != http.StatusOK {
tm := TokenMetrics{ mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path)
Timestamp: time.Now(), tm.ID = mp.queueMetrics(tm)
Model: modelID, mp.emitMetric(tm)
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), return nil
} }
body := recorder.body.Bytes() body := recorder.body.Bytes()
if len(body) == 0 { if len(body) == 0 {
mp.logger.Warn("metrics: empty body, recording minimal metrics") mp.logger.Warn("metrics: empty body, recording minimal metrics")
mp.addMetrics(tm) tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm)
return nil return nil
} }
@@ -303,7 +338,8 @@ func (mp *metricsMonitor) wrapHandler(
body, err = decompressBody(body, encoding) body, err = decompressBody(body, encoding)
if err != nil { if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path) mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
mp.addMetrics(tm) tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm)
return nil return nil
} }
} }
@@ -311,7 +347,8 @@ func (mp *metricsMonitor) wrapHandler(
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path) mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else { } else {
tm = parsed tm.Tokens = parsed.Tokens
tm.DurationMs = parsed.DurationMs
} }
} else { } else {
if gjson.ValidBytes(body) { if gjson.ValidBytes(body) {
@@ -331,7 +368,8 @@ func (mp *metricsMonitor) wrapHandler(
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil { if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path) mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else { } else {
tm = parsedMetrics tm.Tokens = parsedMetrics.Tokens
tm.DurationMs = parsedMetrics.DurationMs
} }
} }
} else { } else {
@@ -342,39 +380,50 @@ func (mp *metricsMonitor) wrapHandler(
// Build capture if enabled and determine if it will be stored // Build capture if enabled and determine if it will be stored
var capture *ReqRespCapture var capture *ReqRespCapture
if mp.enableCaptures { if mp.enableCaptures {
respHeaders := make(map[string]string) var respHeaders map[string]string
for key, values := range recorder.Header() { var respBody []byte
if len(values) > 0 { if (captureFields & captureRespHeaders) != 0 {
respHeaders[key] = values[0] respHeaders = make(map[string]string)
for key, values := range recorder.Header() {
if len(values) > 0 {
respHeaders[key] = values[0]
}
} }
redactHeaders(respHeaders)
delete(respHeaders, "Content-Encoding")
}
if (captureFields & captureRespBody) != 0 {
respBody = body
} }
redactHeaders(respHeaders)
delete(respHeaders, "Content-Encoding")
capture = &ReqRespCapture{ capture = &ReqRespCapture{
ReqPath: request.URL.Path, ReqPath: request.URL.Path,
ReqHeaders: reqHeaders, ReqHeaders: reqHeaders,
ReqBody: reqBody, ReqBody: reqBody,
RespHeaders: respHeaders, RespHeaders: respHeaders,
RespBody: body, RespBody: respBody,
}
compressed, _, err := compressCapture(capture)
if err == nil && len(compressed) <= mp.maxCaptureSize {
tm.HasCapture = true
} }
} }
metricID := mp.addMetrics(tm) metricID := mp.queueMetrics(tm)
tm.ID = metricID
// Store capture if enabled // Store capture if enabled
if capture != nil { if capture != nil {
capture.ID = metricID capture.ID = metricID
mp.addCapture(*capture) if mp.addCapture(*capture) {
tm.HasCapture = true
mp.mu.Lock()
mp.metrics[len(mp.metrics)-1].HasCapture = true
mp.mu.Unlock()
}
} }
mp.emitMetric(tm)
return nil return nil
} }
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) { func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
// Iterate **backwards** through the body looking for the data payload with // Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split. // usage data. This avoids allocating a slice of all lines via bytes.Split.
@@ -428,10 +477,10 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
} }
} }
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream") return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
} }
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) { func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
wallDurationMs := int(time.Since(start).Milliseconds()) wallDurationMs := int(time.Since(start).Milliseconds())
// default values // default values
@@ -481,15 +530,17 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
} }
} }
return TokenMetrics{ return ActivityLogEntry{
Timestamp: time.Now(), Timestamp: time.Now(),
Model: modelID, Model: modelID,
CachedTokens: cachedTokens, Tokens: TokenMetrics{
InputTokens: inputTokens, CachedTokens: cachedTokens,
OutputTokens: outputTokens, InputTokens: inputTokens,
PromptPerSecond: promptPerSecond, OutputTokens: outputTokens,
TokensPerSecond: tokensPerSecond, PromptPerSecond: promptPerSecond,
DurationMs: durationMs, TokensPerSecond: tokensPerSecond,
},
DurationMs: durationMs,
}, nil }, nil
} }
@@ -527,15 +578,11 @@ func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
ResponseWriter: w, ResponseWriter: w,
body: bodyBuffer, body: bodyBuffer,
tee: io.MultiWriter(w, bodyBuffer), tee: io.MultiWriter(w, bodyBuffer),
start: time.Now(),
} }
} }
func (w *responseBodyCopier) Write(b []byte) (int, error) { func (w *responseBodyCopier) Write(b []byte) (int, error) {
if w.start.IsZero() {
w.start = time.Now()
}
// Single write operation that writes to both the response and buffer
return w.tee.Write(b) return w.tee.Write(b)
} }
+344 -175
View File
@@ -12,8 +12,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/cache"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -22,27 +24,29 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
t.Run("adds metrics and assigns ID", func(t *testing.T) { t.Run("adds metrics and assigns ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
metric := TokenMetrics{ metric := ActivityLogEntry{
Model: "test-model", Model: "test-model",
InputTokens: 100, Tokens: TokenMetrics{
OutputTokens: 50, InputTokens: 100,
OutputTokens: 50,
},
} }
mm.addMetrics(metric) mm.queueMetrics(metric)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, 0, metrics[0].ID) assert.Equal(t, 0, metrics[0].ID)
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("increments ID for each metric", func(t *testing.T) { t.Run("increments ID for each metric", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{Model: "model"}) mm.queueMetrics(ActivityLogEntry{Model: "model"})
} }
metrics := mm.getMetrics() metrics := mm.getMetrics()
@@ -57,9 +61,11 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
// Add 5 metrics // Add 5 metrics
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{ mm.queueMetrics(ActivityLogEntry{
Model: "model", Model: "model",
InputTokens: i, Tokens: TokenMetrics{
InputTokens: i,
},
}) })
} }
@@ -72,29 +78,32 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
assert.Equal(t, 4, metrics[2].ID) assert.Equal(t, 4, metrics[2].ID)
}) })
t.Run("emits TokenMetricsEvent", func(t *testing.T) { t.Run("emits ActivityLogEvent", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
receivedEvent := make(chan TokenMetricsEvent, 1) receivedEvent := make(chan ActivityLogEvent, 1)
cancel := event.On(func(e TokenMetricsEvent) { cancel := event.On(func(e ActivityLogEvent) {
receivedEvent <- e receivedEvent <- e
}) })
defer cancel() defer cancel()
metric := TokenMetrics{ metric := ActivityLogEntry{
Model: "test-model", Model: "test-model",
InputTokens: 100, Tokens: TokenMetrics{
OutputTokens: 50, InputTokens: 100,
OutputTokens: 50,
},
} }
mm.addMetrics(metric) mm.queueMetrics(metric)
mm.emitMetric(metric)
select { select {
case evt := <-receivedEvent: case evt := <-receivedEvent:
assert.Equal(t, 0, evt.Metrics.ID) assert.Equal(t, 0, evt.Metrics.ID)
assert.Equal(t, "test-model", evt.Metrics.Model) assert.Equal(t, "test-model", evt.Metrics.Model)
assert.Equal(t, 100, evt.Metrics.InputTokens) assert.Equal(t, 100, evt.Metrics.Tokens.InputTokens)
assert.Equal(t, 50, evt.Metrics.OutputTokens) assert.Equal(t, 50, evt.Metrics.Tokens.OutputTokens)
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for event") t.Fatal("timeout waiting for event")
} }
@@ -111,8 +120,8 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
t.Run("returns copy of metrics", func(t *testing.T) { t.Run("returns copy of metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
mm.addMetrics(TokenMetrics{Model: "model1"}) mm.queueMetrics(ActivityLogEntry{Model: "model1"})
mm.addMetrics(TokenMetrics{Model: "model2"}) mm.queueMetrics(ActivityLogEntry{Model: "model2"})
metrics1 := mm.getMetrics() metrics1 := mm.getMetrics()
metrics2 := mm.getMetrics() metrics2 := mm.getMetrics()
@@ -135,7 +144,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, jsonData) assert.NotNil(t, jsonData)
var metrics []TokenMetrics var metrics []ActivityLogEntry
err = json.Unmarshal(jsonData, &metrics) err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 0, len(metrics))
@@ -143,23 +152,27 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
t.Run("returns valid JSON with metrics", func(t *testing.T) { t.Run("returns valid JSON with metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
mm.addMetrics(TokenMetrics{ mm.queueMetrics(ActivityLogEntry{
Model: "model1", Model: "model1",
InputTokens: 100, Tokens: TokenMetrics{
OutputTokens: 50, InputTokens: 100,
TokensPerSecond: 25.5, OutputTokens: 50,
TokensPerSecond: 25.5,
},
}) })
mm.addMetrics(TokenMetrics{ mm.queueMetrics(ActivityLogEntry{
Model: "model2", Model: "model2",
InputTokens: 200, Tokens: TokenMetrics{
OutputTokens: 100, InputTokens: 200,
TokensPerSecond: 30.0, OutputTokens: 100,
TokensPerSecond: 30.0,
},
}) })
jsonData, err := mm.getMetricsJSON() jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err) assert.NoError(t, err)
var metrics []TokenMetrics var metrics []ActivityLogEntry
err = json.Unmarshal(jsonData, &metrics) err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, len(metrics)) assert.Equal(t, 2, len(metrics))
@@ -190,14 +203,14 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("successful request with timings data", func(t *testing.T) { t.Run("successful request with timings data", func(t *testing.T) {
@@ -226,17 +239,17 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
assert.Equal(t, 20, metrics[0].CachedTokens) assert.Equal(t, 20, metrics[0].Tokens.CachedTokens)
assert.Equal(t, 150.5, metrics[0].PromptPerSecond) assert.Equal(t, 150.5, metrics[0].Tokens.PromptPerSecond)
assert.Equal(t, 25.5, metrics[0].TokensPerSecond) assert.Equal(t, 25.5, metrics[0].Tokens.TokensPerSecond)
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500 assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
}) })
@@ -265,18 +278,18 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
// When timings data is present, it takes precedence // When timings data is present, it takes precedence
assert.Equal(t, 10, metrics[0].InputTokens) assert.Equal(t, 10, metrics[0].Tokens.InputTokens)
assert.Equal(t, 20, metrics[0].OutputTokens) assert.Equal(t, 20, metrics[0].Tokens.OutputTokens)
}) })
t.Run("non-OK status code does not record metrics", func(t *testing.T) { t.Run("non-OK status code records partial metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
@@ -289,11 +302,16 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, "/test", metrics[0].ReqPath)
assert.Equal(t, http.StatusBadRequest, metrics[0].RespStatusCode)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("empty response body records minimal metrics", func(t *testing.T) { t.Run("empty response body records minimal metrics", func(t *testing.T) {
@@ -308,14 +326,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("invalid JSON records minimal metrics", func(t *testing.T) { t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
@@ -332,14 +350,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("next handler error is propagated", func(t *testing.T) { t.Run("next handler error is propagated", func(t *testing.T) {
@@ -354,7 +372,7 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.Equal(t, expectedErr, err) assert.Equal(t, expectedErr, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
@@ -377,14 +395,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("infill request extracts timings from last array element", func(t *testing.T) { t.Run("infill request extracts timings from last array element", func(t *testing.T) {
@@ -416,17 +434,17 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 150, metrics[0].InputTokens) assert.Equal(t, 150, metrics[0].Tokens.InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens) assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
assert.Equal(t, 30, metrics[0].CachedTokens) assert.Equal(t, 30, metrics[0].Tokens.CachedTokens)
assert.Equal(t, 200.5, metrics[0].PromptPerSecond) assert.Equal(t, 200.5, metrics[0].Tokens.PromptPerSecond)
assert.Equal(t, 35.5, metrics[0].TokensPerSecond) assert.Equal(t, 35.5, metrics[0].Tokens.TokensPerSecond)
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800 assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
}) })
@@ -446,14 +464,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
} }
@@ -472,15 +490,11 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
assert.Equal(t, string(testData), rec.Body.String()) assert.Equal(t, string(testData), rec.Body.String())
}) })
t.Run("sets start time on first write", func(t *testing.T) { t.Run("sets start time on creation", func(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer) copier := newBodyCopier(ginCtx.Writer)
assert.True(t, copier.StartTime().IsZero())
copier.Write([]byte("test"))
assert.False(t, copier.StartTime().IsZero()) assert.False(t, copier.StartTime().IsZero())
}) })
@@ -507,7 +521,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
} }
func TestMetricsMonitor_Concurrent(t *testing.T) { func TestMetricsMonitor_Concurrent(t *testing.T) {
t.Run("concurrent addMetrics is safe", func(t *testing.T) { t.Run("concurrent queueMetrics is safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 1000, 0) mm := newMetricsMonitor(testLogger, 1000, 0)
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -519,10 +533,12 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
for j := 0; j < metricsPerGoroutine; j++ { for j := 0; j < metricsPerGoroutine; j++ {
mm.addMetrics(TokenMetrics{ mm.queueMetrics(ActivityLogEntry{
Model: "test-model", Model: "test-model",
InputTokens: id*1000 + j, Tokens: TokenMetrics{
OutputTokens: j, InputTokens: id*1000 + j,
OutputTokens: j,
},
}) })
} }
}(i) }(i)
@@ -542,7 +558,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
// Writer goroutine // Writer goroutine
go func() { go func() {
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
mm.addMetrics(TokenMetrics{Model: "test-model"}) mm.queueMetrics(ActivityLogEntry{Model: "test-model"})
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
done <- true done <- true
@@ -586,10 +602,10 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
metrics, err := parseMetrics("test-model", start, usage, timings) metrics, err := parseMetrics("test-model", start, usage, timings)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 5, metrics.InputTokens) assert.Equal(t, 5, metrics.Tokens.InputTokens)
assert.Equal(t, 1, metrics.OutputTokens) assert.Equal(t, 1, metrics.Tokens.OutputTokens)
assert.Equal(t, 10.0, metrics.PromptPerSecond) assert.Equal(t, 10.0, metrics.Tokens.PromptPerSecond)
assert.Equal(t, 2.0, metrics.TokensPerSecond) assert.Equal(t, 2.0, metrics.Tokens.TokensPerSecond)
assert.GreaterOrEqual(t, metrics.DurationMs, 5000) assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
}) })
@@ -623,14 +639,14 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
// Should use timings values, not usage values // Should use timings values, not usage values
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("handles missing cache_n in timings", func(t *testing.T) { t.Run("handles missing cache_n in timings", func(t *testing.T) {
@@ -658,12 +674,12 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present assert.Equal(t, -1, metrics[0].Tokens.CachedTokens) // Default value when not present
}) })
} }
@@ -693,13 +709,13 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) { t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
@@ -722,14 +738,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("v1/responses format with nested response.usage", func(t *testing.T) { t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
@@ -751,14 +767,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 17, metrics[0].InputTokens) assert.Equal(t, 17, metrics[0].Tokens.InputTokens)
assert.Equal(t, 23, metrics[0].OutputTokens) assert.Equal(t, 23, metrics[0].Tokens.OutputTokens)
}) })
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) { t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
@@ -777,14 +793,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
} }
@@ -792,20 +808,22 @@ data: [DONE]
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) { func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
mm := newMetricsMonitor(testLogger, 1000, 0) mm := newMetricsMonitor(testLogger, 1000, 0)
metric := TokenMetrics{ metric := ActivityLogEntry{
Model: "test-model", Model: "test-model",
CachedTokens: 100, Tokens: TokenMetrics{
InputTokens: 500, CachedTokens: 100,
OutputTokens: 250, InputTokens: 500,
PromptPerSecond: 1200.5, OutputTokens: 250,
TokensPerSecond: 45.8, PromptPerSecond: 1200.5,
DurationMs: 5000, TokensPerSecond: 45.8,
Timestamp: time.Now(), },
DurationMs: 5000,
Timestamp: time.Now(),
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
mm.addMetrics(metric) mm.queueMetrics(metric)
} }
} }
@@ -813,20 +831,22 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
// Test performance with a smaller buffer where wrapping occurs more frequently // Test performance with a smaller buffer where wrapping occurs more frequently
mm := newMetricsMonitor(testLogger, 100, 0) mm := newMetricsMonitor(testLogger, 100, 0)
metric := TokenMetrics{ metric := ActivityLogEntry{
Model: "test-model", Model: "test-model",
CachedTokens: 100, Tokens: TokenMetrics{
InputTokens: 500, CachedTokens: 100,
OutputTokens: 250, InputTokens: 500,
PromptPerSecond: 1200.5, OutputTokens: 250,
TokensPerSecond: 45.8, PromptPerSecond: 1200.5,
DurationMs: 5000, TokensPerSecond: 45.8,
Timestamp: time.Now(), },
DurationMs: 5000,
Timestamp: time.Now(),
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
mm.addMetrics(metric) mm.queueMetrics(metric)
} }
} }
@@ -855,14 +875,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("deflate encoded response", func(t *testing.T) { t.Run("deflate encoded response", func(t *testing.T) {
@@ -889,14 +909,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 200, metrics[0].InputTokens) assert.Equal(t, 200, metrics[0].Tokens.InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens) assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
}) })
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) { t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
@@ -917,14 +937,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) // Should not return error, just log warning assert.NoError(t, err) // Should not return error, just log warning
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) { t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
@@ -944,13 +964,13 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, 300, metrics[0].InputTokens) assert.Equal(t, 300, metrics[0].Tokens.InputTokens)
assert.Equal(t, 100, metrics[0].OutputTokens) assert.Equal(t, 100, metrics[0].Tokens.OutputTokens)
}) })
} }
@@ -989,7 +1009,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
mm.addCapture(capture) mm.addCapture(capture)
// Should not store capture // Should not store capture
assert.Nil(t, mm.getCaptureByID(0, false)) assert.Nil(t, mm.getCaptureByID(0))
}) })
t.Run("adds capture when enabled", func(t *testing.T) { t.Run("adds capture when enabled", func(t *testing.T) {
@@ -1002,22 +1022,18 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
} }
mm.addCapture(capture) mm.addCapture(capture)
retrieved := mm.getCaptureByID(0, true) captured := mm.getCaptureByID(0)
assert.NotNil(t, retrieved) assert.NotNil(t, captured)
assert.Equal(t, 0, captured.ID)
var decoded ReqRespCapture assert.Equal(t, []byte("test request"), captured.ReqBody)
err := json.Unmarshal(retrieved, &decoded) assert.Equal(t, []byte("test response"), captured.RespBody)
assert.NoError(t, err)
assert.Equal(t, 0, decoded.ID)
assert.Equal(t, []byte("test request"), decoded.ReqBody)
assert.Equal(t, []byte("test response"), decoded.RespBody)
}) })
t.Run("evicts oldest when exceeding max size", func(t *testing.T) { t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
// Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes. // Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit. // 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
mm.maxCaptureSize = 450 mm.captureCache = cache.New(450)
// Use random-looking data that doesn't compress well with zstd // Use random-looking data that doesn't compress well with zstd
rng := rand.New(rand.NewSource(42)) rng := rand.New(rand.NewSource(42))
@@ -1033,16 +1049,14 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
// Adding capture3 should evict capture1 // Adding capture3 should evict capture1
mm.addCapture(capture3) mm.addCapture(capture3)
assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted") assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
retrieved := mm.getCaptureByID(1, true) assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
assert.NotNil(t, retrieved, "capture 1 should exist") assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
retrieved = mm.getCaptureByID(2, true)
assert.NotNil(t, retrieved, "capture 2 should exist")
}) })
t.Run("skips capture larger than max size", func(t *testing.T) { t.Run("skips capture larger than max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100 mm.captureCache = cache.New(100)
// Use random data that doesn't compress well to create an oversized capture // Use random data that doesn't compress well to create an oversized capture
rng := rand.New(rand.NewSource(99)) rng := rand.New(rand.NewSource(99))
@@ -1050,7 +1064,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
rng.Read(largeCapture.ReqBody) rng.Read(largeCapture.ReqBody)
mm.addCapture(largeCapture) mm.addCapture(largeCapture)
assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored") assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
}) })
} }
@@ -1058,7 +1072,7 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
t.Run("returns nil for non-existent ID", func(t *testing.T) { t.Run("returns nil for non-existent ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
assert.Nil(t, mm.getCaptureByID(999, false)) assert.Nil(t, mm.getCaptureByID(999))
}) })
t.Run("returns decompressed capture by ID", func(t *testing.T) { t.Run("returns decompressed capture by ID", func(t *testing.T) {
@@ -1071,18 +1085,14 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
} }
mm.addCapture(capture) mm.addCapture(capture)
retrieved := mm.getCaptureByID(42, true) captured := mm.getCaptureByID(42)
assert.NotNil(t, retrieved) assert.NotNil(t, captured)
assert.Equal(t, 42, captured.ID)
var decoded ReqRespCapture assert.Equal(t, []byte("test request"), captured.ReqBody)
err := json.Unmarshal(retrieved, &decoded) assert.Equal(t, []byte("test response"), captured.RespBody)
assert.NoError(t, err)
assert.Equal(t, 42, decoded.ID)
assert.Equal(t, []byte("test request"), decoded.ReqBody)
assert.Equal(t, []byte("test response"), decoded.RespBody)
}) })
t.Run("returns compressed bytes when decompress=false", func(t *testing.T) { t.Run("stores data as compressed bytes", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{ capture := ReqRespCapture{
@@ -1092,10 +1102,12 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
} }
mm.addCapture(capture) mm.addCapture(capture)
compressed := mm.getCaptureByID(42, false) compressed, exists := mm.getCompressedBytes(42)
assert.True(t, exists)
assert.NotNil(t, compressed) assert.NotNil(t, compressed)
// Compressed data should not be valid JSON (it's zstd-compressed) // Compressed data should not be valid CBOR (it's zstd-compressed)
assert.False(t, gjson.ValidBytes(compressed)) var decoded ReqRespCapture
assert.Error(t, cbor.Unmarshal(compressed, &decoded))
}) })
} }
@@ -1164,7 +1176,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
// Check metric was recorded // Check metric was recorded
@@ -1173,12 +1185,8 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
metricID := metrics[0].ID metricID := metrics[0].ID
// Check capture was stored with same ID (decompressed) // Check capture was stored with same ID (decompressed)
captureData := mm.getCaptureByID(metricID, true) capture := mm.getCaptureByID(metricID)
assert.NotNil(t, captureData) assert.NotNil(t, capture)
var capture ReqRespCapture
err = json.Unmarshal(captureData, &capture)
assert.NoError(t, err)
assert.Equal(t, metricID, capture.ID) assert.Equal(t, metricID, capture.ID)
assert.Equal(t, []byte(requestBody), capture.ReqBody) assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Equal(t, []byte(responseBody), capture.RespBody) assert.Equal(t, []byte(responseBody), capture.RespBody)
@@ -1206,7 +1214,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
// Metrics should still be recorded // Metrics should still be recorded
@@ -1214,7 +1222,168 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
// But no capture // But no capture
capture := mm.getCaptureByID(metrics[0].ID, false) assert.Nil(t, mm.getCaptureByID(metrics[0].ID))
assert.Nil(t, capture) })
}
func TestMetricsMonitor_WrapHandler_PartialCaptures(t *testing.T) {
requestBody := `{"model": "test"}`
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Custom", "header-value")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
t.Run("only request headers", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("only request body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("only response headers", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespHeaders, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
assert.Nil(t, capture.RespBody)
})
t.Run("only response body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Equal(t, []byte(responseBody), capture.RespBody)
})
t.Run("captureReqAll", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqAll, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("captureRespAll", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespAll, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
assert.Equal(t, []byte(responseBody), capture.RespBody)
})
t.Run("no flags", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureFields(0), nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("mixed flags req headers and resp body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders|captureRespBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Equal(t, []byte(responseBody), capture.RespBody)
}) })
} }
+325 -275
View File
@@ -332,41 +332,77 @@ func (pm *ProxyManager) setupGinEngine() {
// Set up routes using the Gin engine // Set up routes using the Gin engine
// Protected routes use pm.apiKeyAuth() middleware // Protected routes use pm.apiKeyAuth() middleware
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) llmHandler := pm.mkProxyJSONHandler(captureAll)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support legacy /v1/completions api, see issue #12 // Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570) // Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support anthropic count_tokens API (Also added in the above PR) // Support anthropic count_tokens API (Also added in the above PR)
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support embeddings and reranking // Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /reranking endpoint + aliases // llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /infill endpoint for code infilling // llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /completion endpoint // llama-server's /completion endpoint
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST(
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) "/v1/audio/speech",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST(
"/v1/audio/voices",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqHeaders|captureRespAll),
)
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler) pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST(
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler) "/v1/audio/transcriptions",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders|captureRespBody),
)
pm.ginEngine.POST(
"/v1/images/generations",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST(
"/v1/images/edits",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders),
)
// sd.cpp /sdapi/v1 endpoints // sd.cpp /sdapi/v1 endpoints
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/sdapi/v1/txt2img",
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST("/sdapi/v1/img2img",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqHeaders|captureRespHeaders),
)
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler) pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
@@ -647,7 +683,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath) searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
if !modelFound { if !modelFound {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") pm.sendErrorResponse(c, http.StatusNotFound, "model not found")
return return
} }
@@ -686,7 +722,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// attempt to record metrics if it is a POST request // attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" { if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil { if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, captureNone, handler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath) pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return return
@@ -700,280 +736,294 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
} }
} }
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) { func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body) return func(c *gin.Context) {
if err != nil { bodyBytes, err := io.ReadAll(c.Request.Body)
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body") if err != nil {
return pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
} return
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
// Look for a matching local model first
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
modelID, found := pm.config.RealModelName(requestedModel)
if found {
var localHandler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
localHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
localHandler = processGroup.ProxyRequest
} }
// issue #69 allow custom model names to be sent to upstream requestedModel := gjson.GetBytes(bodyBytes, "model").String()
useModelName := pm.config.Models[modelID].UseModelName if requestedModel == "" {
if useModelName != "" { pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) return
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
} }
// issue #174 strip parameters from the JSON body // Look for a matching local model first
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams() var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error()) modelID, found := pm.config.RealModelName(requestedModel)
} else { if found {
var localHandler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
localHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
localHandler = processGroup.ProxyRequest
}
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[modelID].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
return
}
}
}
// issue #453 set/override parameters in the JSON body
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
for _, key := range setParamsByIDKeys {
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = localHandler
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
// issue #453 apply filters for peer requests
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
// Apply stripParams - remove specified parameters from request
stripParams := peerFilters.SanitizedStripParams()
for _, param := range stripParams { for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param) pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
return
}
}
// Apply setParams - set/override specified parameters in request
setParams, setParamKeys := peerFilters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it, as we already have all the body bytes; see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, cf, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
}
}
// mkPostFormHandler creates a POST form handler for inference backends
// with a custom captureFields to filter out large binary requests or responses.
func (pm *ProxyManager) mkPostFormHandler(cf captureFields) func(*gin.Context) {
return func(c *gin.Context) {
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Look for a matching local model first, then check peers
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
var useModelName string
modelID, found := pm.config.RealModelName(requestedModel)
if found {
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return return
} }
} }
} }
// issue #453 set/override parameters in the JSON body // Copy all files from the original request
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams() for key, fileHeaders := range c.Request.MultipartForm.File {
for _, key := range setParamKeys { for _, fileHeader := range fileHeaders {
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key) formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) if err != nil {
if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) return
return
}
}
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
for _, key := range setParamsByIDKeys {
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = localHandler
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
// issue #453 apply filters for peer requests
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
// Apply stripParams - remove specified parameters from request
stripParams := peerFilters.SanitizedStripParams()
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
return
}
}
// Apply setParams - set/override specified parameters in request
setParams, setParamKeys := peerFilters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Look for a matching local model first, then check peers
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
var useModelName string
modelID, found := pm.config.RealModelName(requestedModel)
if found {
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
} }
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request file, err := fileHeader.Open()
for key, fileHeaders := range c.Request.MultipartForm.File { if err != nil {
for _, fileHeader := range fileHeaders { pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename) return
if err != nil { }
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open() if _, err = io.Copy(formFile, file); err != nil {
if err != nil { file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file") pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return return
} }
if _, err = io.Copy(formFile, file); err != nil {
file.Close() file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data") }
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if pm.metricsMonitor != nil {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, modifiedReq, cf, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return return
} }
file.Close()
} }
} }
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
} }
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) { func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
+10 -20
View File
@@ -158,7 +158,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
} }
} }
sendMetrics := func(metrics []TokenMetrics) { sendMetrics := func(metrics []ActivityLogEntry) {
jsonData, err := json.Marshal(metrics) jsonData, err := json.Marshal(metrics)
if err == nil { if err == nil {
select { select {
@@ -205,8 +205,8 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
/** /**
* Send Metrics data * Send Metrics data
*/ */
defer event.On(func(e TokenMetricsEvent) { defer event.On(func(e ActivityLogEvent) {
sendMetrics([]TokenMetrics{e.Metrics}) sendMetrics([]ActivityLogEntry{e.Metrics})
})() })()
/** /**
@@ -290,26 +290,16 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
return return
} }
data, exists := pm.metricsMonitor.getCompressedBytes(id) capture := pm.metricsMonitor.getCaptureByID(id)
if !exists { if capture == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"}) c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
return return
} }
c.Header("Vary", "Accept-Encoding") jsonBytes, err := json.Marshal(capture)
if err != nil {
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"})
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd") return
if hasZstd {
c.Header("Content-Encoding", "zstd")
c.Data(http.StatusOK, "application/json", data)
} else {
decompressed, err := decompressCapture(data)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
return
}
c.Data(http.StatusOK, "application/json", decompressed)
} }
c.Data(http.StatusOK, "application/json", jsonBytes)
} }
+13
View File
@@ -32,6 +32,13 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("X-Accel-Buffering", "no") c.Header("X-Accel-Buffering", "no")
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/") logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
// Handle case where query string might be included in the parameter
// (can happen with catch-all routes on some versions/setups)
if idx := strings.Index(logMonitorId, "?"); idx != -1 {
logMonitorId = logMonitorId[:idx]
}
logger, err := pm.getLogger(logMonitorId) logger, err := pm.getLogger(logMonitorId)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, err.Error()) c.String(http.StatusBadRequest, err.Error())
@@ -100,6 +107,12 @@ func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
return process.Logger(), nil return process.Logger(), nil
} }
} }
// also check the matrix when processGroups doesn't contain the model
if pm.matrix != nil {
if process, found := pm.matrix.GetProcess(name); found {
return process.Logger(), nil
}
}
} }
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID") return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
+173
View File
@@ -0,0 +1,173 @@
package proxy
import (
"context"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestLogMonitorIdQueryParameterStripping checks the "drop everything from the
// first '?' onward" rule on a handful of log monitor IDs. The rule is
// re-implemented locally (it is a simulation, not a call into the handler).
func TestLogMonitorIdQueryParameterStripping(t *testing.T) {
	type stripCase struct {
		name     string
		input    string
		expected string
	}

	// Local copy of the stripping rule under test.
	stripQuery := func(id string) string {
		cut := strings.Index(id, "?")
		if cut == -1 {
			return id
		}
		return id[:cut]
	}

	cases := []stripCase{
		{"upstream without query param", "upstream", "upstream"},
		{"upstream with query param", "upstream?no-history", "upstream"},
		{"proxy with multiple query params", "proxy?no-history&foo=bar", "proxy"},
		{"model with slash and query param", "author/model?no-history", "author/model"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := stripQuery(tc.input)
			if got == tc.expected {
				return
			}
			t.Errorf("Query parameter stripping failed: got %q, want %q", got, tc.expected)
		})
	}
}
// TestProxyManager_GetLogger_ProcessGroups verifies getLogger resolves the
// well-known "proxy"/"upstream" loggers and a model ID managed by processGroups.
func TestProxyManager_GetLogger_ProcessGroups(t *testing.T) {
cfg := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
`)
pm := New(cfg)
defer pm.StopProcesses(StopImmediately)
tests := []struct {
id string
wantErr bool
}{
{"proxy", false},
{"upstream", false},
{"model1", false},
{"does-not-exist", true},
}
for _, tt := range tests {
t.Run(tt.id, func(t *testing.T) {
logger, err := pm.getLogger(tt.id)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid logger")
} else {
require.NoError(t, err)
assert.NotNil(t, logger)
}
})
}
}
// TestProxyManager_GetLogger_Matrix checks that getLogger resolves model IDs
// when the proxy is configured with a swap matrix (pm.processGroups is empty
// for matrix-managed models), alongside the "proxy" and "upstream" loggers, and
// still rejects an unknown ID.
func TestProxyManager_GetLogger_Matrix(t *testing.T) {
	// Build the model map in the same order the IDs are listed.
	models := make(map[string]config.ModelConfig, 2)
	for _, id := range []string{"model1", "model2"} {
		models[id] = getTestSimpleResponderConfig(id)
	}

	pm := New(config.Config{
		HealthCheckTimeout: 15,
		Models:             models,
		ExpandedSets: []config.ExpandedSet{
			{SetName: "s1", Models: []string{"model1", "model2"}},
		},
		Matrix: &config.MatrixConfig{},
	})
	defer pm.StopProcesses(StopImmediately)

	cases := []struct {
		id      string
		wantErr bool
	}{
		{id: "proxy"},
		{id: "upstream"},
		{id: "model1"},
		{id: "model2"},
		{id: "does-not-exist", wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.id, func(t *testing.T) {
			logger, err := pm.getLogger(tc.id)

			if !tc.wantErr {
				require.NoError(t, err)
				assert.NotNil(t, logger)
				return
			}

			require.Error(t, err)
			assert.Contains(t, err.Error(), "invalid logger")
		})
	}
}
// TestProxyManager_StreamLogs_Matrix verifies that /logs/stream/<modelID>
// returns 200 (not 400) for a model managed by the swap matrix.
func TestProxyManager_StreamLogs_Matrix(t *testing.T) {
cfg := config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"matrix-model": getTestSimpleResponderConfig("matrix-model"),
},
ExpandedSets: []config.ExpandedSet{
{SetName: "s1", Models: []string{"matrix-model"}},
},
Matrix: &config.MatrixConfig{},
}
pm := New(cfg)
defer pm.StopProcesses(StopImmediately)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
req := httptest.NewRequest("GET", "/logs/stream/matrix-model", nil)
req = req.WithContext(ctx)
rec := CreateTestResponseRecorder()
done := make(chan struct{})
go func() {
defer close(done)
pm.ServeHTTP(rec, req)
}()
<-ctx.Done()
<-done
assert.Equal(t, 200, rec.Code)
}
+58
View File
@@ -1721,3 +1721,61 @@ models:
assert.Contains(t, w.Body.String(), "could not find suitable handler") assert.Contains(t, w.Body.String(), "could not find suitable handler")
}) })
} }
// TestProxyManager_AudioTranscriptionCapture posts a multipart form to
// /v1/audio/transcriptions and checks the recorded capture: request headers are
// stored with Authorization redacted, response headers and body are stored, and
// the request body is not.
//
// Guards stop the test (instead of panicking) when no metric or no capture was
// recorded, so a regression shows up as a normal assertion failure.
func TestProxyManager_AudioTranscriptionCapture(t *testing.T) {
	// NOTE(review): the YAML literal is kept exactly as it appears in this view;
	// confirm the nesting under "models:" against the checked-in file.
	cfg := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
captureBuffer: 5
models:
TheExpectedModel:
cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel
`)
	proxy := New(cfg)
	defer proxy.StopProcesses(StopWaitForInflightRequest)
	injectTestHandlers(proxy, nil)

	// Build the multipart request body: a "model" field plus one file part.
	var b bytes.Buffer
	w := multipart.NewWriter(&b)

	fw, err := w.CreateFormField("model")
	assert.NoError(t, err)
	_, err = fw.Write([]byte("TheExpectedModel"))
	assert.NoError(t, err)

	fw, err = w.CreateFormFile("file", "test.mp3")
	assert.NoError(t, err)
	_, err = fw.Write([]byte("test audio content"))
	assert.NoError(t, err)

	// Close writes the trailing boundary; a failure here would produce a
	// malformed body, so it must not be ignored.
	assert.NoError(t, w.Close())

	req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
	req.Header.Set("Content-Type", w.FormDataContentType())
	req.Header.Set("Authorization", "Bearer mysecret")
	req.Header.Set("X-Custom-Req", "req-value")
	rec := CreateTestResponseRecorder()
	proxy.ServeHTTP(rec, req)
	assert.Equal(t, http.StatusOK, rec.Code)

	// Verify capture exists. Bail out early if not, since the assertions
	// below index into metrics and dereference the capture.
	metrics := proxy.metricsMonitor.getMetrics()
	if !assert.Len(t, metrics, 1) {
		return
	}
	assert.True(t, metrics[0].HasCapture)

	capture := proxy.metricsMonitor.getCaptureByID(metrics[0].ID)
	if !assert.NotNil(t, capture) {
		return
	}

	// Should capture request headers (sensitive ones redacted)
	assert.NotEmpty(t, capture.ReqHeaders)
	assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
	assert.Equal(t, "req-value", capture.ReqHeaders["X-Custom-Req"])

	// Should capture response headers
	assert.NotNil(t, capture.RespHeaders)

	// Should NOT capture the request body, but should capture the (text)
	// response body.
	assert.Nil(t, capture.ReqBody)
	assert.NotNil(t, capture.RespBody)
}
+3 -3
View File
@@ -2788,9 +2788,9 @@
} }
}, },
"node_modules/postcss": { "node_modules/postcss": {
"version": "8.5.8", "version": "8.5.12",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==", "integrity": "sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==",
"dev": true, "dev": true,
"funding": [ "funding": [
{ {
@@ -0,0 +1,97 @@
<script lang="ts">
import { inFlightRequests, metrics } from "../stores/api";
import { persistentStore } from "../stores/persistent";
import { calculateHistogramData } from "../lib/histogram";
import TokenHistogram from "./TokenHistogram.svelte";
const nf = new Intl.NumberFormat();
const histogramCollapsed = persistentStore<boolean>("activity-histogram-collapsed", false);
// Summary derived from the $metrics and $inFlightRequests stores: request and
// token totals plus histogram data for prompt and generation speeds. A
// histogram is null when no entry reports a positive speed for it.
let stats = $derived.by(() => {
  let totalInputTokens = 0;
  let totalOutputTokens = 0;
  let totalCacheTokens = 0;

  // One pass over the entries for all three token totals.
  for (const entry of $metrics) {
    totalInputTokens += entry.tokens.input_tokens;
    totalOutputTokens += entry.tokens.output_tokens;
    totalCacheTokens += entry.tokens.cache_tokens;
  }

  // Only strictly positive speeds are fed to the histograms.
  const promptPerSecond = $metrics.map((entry) => entry.tokens.prompt_per_second).filter((speed) => speed > 0);
  const tokensPerSecond = $metrics.map((entry) => entry.tokens.tokens_per_second).filter((speed) => speed > 0);

  let promptHistogramData = null;
  if (promptPerSecond.length > 0) {
    promptHistogramData = calculateHistogramData(promptPerSecond);
  }

  let genHistogramData = null;
  if (tokensPerSecond.length > 0) {
    genHistogramData = calculateHistogramData(tokensPerSecond);
  }

  return {
    totalRequests: $metrics.length,
    totalInputTokens,
    totalOutputTokens,
    totalCacheTokens,
    inFlightRequests: $inFlightRequests,
    promptHistogramData,
    genHistogramData,
  };
});
</script>
<div class="card relative p-3">
<button
class="absolute top-2 right-2 w-6 h-6 flex items-center justify-center rounded-full border border-gray-300 dark:border-gray-600 text-gray-400 dark:text-gray-500 hover:text-gray-600 dark:hover:text-gray-300 hover:border-gray-400 dark:hover:border-gray-400 transition-colors"
onclick={() => ($histogramCollapsed = !$histogramCollapsed)}
title={$histogramCollapsed ? "Show histograms" : "Hide histograms"}
>
{#if $histogramCollapsed}
<svg class="w-3.5 h-3.5" viewBox="0 0 16 16" fill="currentColor">
<path d="M4.5 6l3.5 4 3.5-4H4.5z" />
</svg>
{:else}
<svg class="w-3 h-3" viewBox="0 0 16 16" fill="currentColor">
<path d="M3.5 3.5l9 9M12.5 3.5l-9 9" stroke="currentColor" stroke-width="2" stroke-linecap="round" fill="none" />
</svg>
{/if}
</button>
{#if !$histogramCollapsed}
<div class="flex flex-col sm:flex-row gap-6 mb-3">
<div class="w-full sm:w-1/2 min-w-0">
<div class="text-sm font-medium text-gray-500 dark:text-gray-400 mb-1">Prompt Processing</div>
{#if stats.promptHistogramData}
<TokenHistogram
data={stats.promptHistogramData}
unit="prompt tokens/sec"
colorClass="text-amber-500 dark:text-amber-400"
/>
{:else}
<div class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">No prompt speed data yet</div>
{/if}
</div>
<div class="w-full sm:w-1/2 min-w-0">
<div class="text-sm font-medium text-gray-500 dark:text-gray-400 mb-1">Token Generation</div>
{#if stats.genHistogramData}
<TokenHistogram data={stats.genHistogramData} unit="tokens/sec" />
{:else}
<div class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">No generation speed data yet</div>
{/if}
</div>
</div>
{/if}
<div class="grid grid-cols-4 gap-x-6 gap-y-1 text-sm">
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Requests</div>
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Cached</div>
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Processed</div>
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Generated</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalRequests)}</span> completed,
<span class="font-semibold">{nf.format(stats.inFlightRequests)}</span> waiting
</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalCacheTokens)}</span> tokens
</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalInputTokens)}</span> tokens
</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalOutputTokens)}</span> tokens
</div>
</div>
</div>
@@ -106,6 +106,7 @@
const delta = parsed.choices?.[0]?.delta; const delta = parsed.choices?.[0]?.delta;
if (delta?.content) result.content += delta.content; if (delta?.content) result.content += delta.content;
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content; if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
if (delta?.reasoning) result.reasoning += delta.reasoning;
} catch { } catch {
// skip unparseable lines // skip unparseable lines
} }
+7 -1
View File
@@ -50,7 +50,7 @@
<a <a
href="/" href="/"
use:link use:link
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}" class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold underline underline-offset-4' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
> >
Playground Playground
</a> </a>
@@ -59,6 +59,8 @@
use:link use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap" class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/models", $currentRoute)} class:font-semibold={isActive("/models", $currentRoute)}
class:underline={isActive("/models", $currentRoute)}
class:underline-offset-4={isActive("/models", $currentRoute)}
> >
Models Models
</a> </a>
@@ -67,6 +69,8 @@
use:link use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap" class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/activity", $currentRoute)} class:font-semibold={isActive("/activity", $currentRoute)}
class:underline={isActive("/activity", $currentRoute)}
class:underline-offset-4={isActive("/activity", $currentRoute)}
> >
Activity Activity
</a> </a>
@@ -75,6 +79,8 @@
use:link use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap" class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/logs", $currentRoute)} class:font-semibold={isActive("/logs", $currentRoute)}
class:underline={isActive("/logs", $currentRoute)}
class:underline-offset-4={isActive("/logs", $currentRoute)}
> >
Logs Logs
</a> </a>
-167
View File
@@ -1,167 +0,0 @@
<script lang="ts">
import { inFlightRequests, metrics } from "../stores/api";
import TokenHistogram from "./TokenHistogram.svelte";
// Binned distribution of tokens/sec samples together with summary percentiles.
interface HistogramData {
// Number of samples that fell into each bin.
bins: number[];
// Smallest sample value.
min: number;
// Largest sample value.
max: number;
// Width of one bin: (max - min) / number of bins.
binSize: number;
// 99th, 95th and 50th percentile sample values.
p99: number;
p95: number;
p50: number;
}
// Summary derived from the $metrics and $inFlightRequests stores: request and
// token totals, tokens/sec percentiles (formatted to two decimals) and
// histogram data. When no entry has both a positive duration and positive
// output tokens, the percentiles are "0" and histogramData is null.
let stats = $derived.by(() => {
  const totalRequests = $metrics.length;
  const inFlight = $inFlightRequests;

  // Result shape used when no tokens/sec figure can be computed.
  const withoutHistogram = (totalInputTokens: number, totalOutputTokens: number) => ({
    totalRequests,
    totalInputTokens,
    totalOutputTokens,
    inFlightRequests: inFlight,
    tokenStats: { p99: "0", p95: "0", p50: "0" },
    histogramData: null,
  });

  if (totalRequests === 0) {
    return withoutHistogram(0, 0);
  }

  const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
  const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);

  // Calculate token statistics using output_tokens and duration_ms
  const validMetrics = $metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
  if (validMetrics.length === 0) {
    return withoutHistogram(totalInputTokens, totalOutputTokens);
  }

  // Calculate tokens/second for each valid metric
  const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));

  // Sort for percentile calculation
  const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
  const sampleCount = sortedTokensPerSecond.length;

  const p99 = sortedTokensPerSecond[Math.floor(sampleCount * 0.99)];
  const p95 = sortedTokensPerSecond[Math.floor(sampleCount * 0.95)];
  const p50 = sortedTokensPerSecond[Math.floor(sampleCount * 0.5)];

  // The sorted array already gives min and max; this also avoids spreading a
  // large array into Math.min / Math.max arguments.
  const min = sortedTokensPerSecond[0];
  const max = sortedTokensPerSecond[sampleCount - 1];

  // Create histogram data
  const binCount = Math.min(30, Math.max(10, Math.floor(sampleCount / 5)));
  const binSize = (max - min) / binCount;

  const bins = Array(binCount).fill(0);
  tokensPerSecond.forEach((value) => {
    // If every sample is identical, binSize is 0 and the division would give
    // NaN, so nothing would be counted; put those samples in the first bin.
    const rawIndex = binSize > 0 ? Math.floor((value - min) / binSize) : 0;
    const binIndex = Math.min(rawIndex, binCount - 1);
    bins[binIndex]++;
  });

  const histogramData: HistogramData = {
    bins,
    min,
    max,
    binSize,
    p99,
    p95,
    p50,
  };

  return {
    totalRequests,
    totalInputTokens,
    totalOutputTokens,
    inFlightRequests: inFlight,
    tokenStats: {
      p99: p99.toFixed(2),
      p95: p95.toFixed(2),
      p50: p50.toFixed(2),
    },
    histogramData,
  };
});
const nf = new Intl.NumberFormat();
</script>
<!-- Summary card: a one-row table showing request counts, token totals and tokens/sec percentiles. -->
<div class="card">
<div class="rounded-lg overflow-hidden border border-card-border-inner">
<table class="min-w-full divide-y divide-card-border-inner">
<thead class="bg-secondary">
<tr>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">Requests</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Processed
</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Generated
</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Token Stats (tokens/sec)
</th>
</tr>
</thead>
<tbody class="bg-surface divide-y divide-card-border-inner">
<tr class="hover:bg-secondary">
<!-- Requests: completed count (stats.totalRequests) and in-flight count (stats.inFlightRequests) -->
<td class="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">
<div class="flex flex-col gap-1">
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Completed: {nf.format(stats.totalRequests)}</span>
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Waiting: {nf.format(stats.inFlightRequests)}</span>
</div>
</td>
<!-- Processed: sum of input tokens across all metrics -->
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{nf.format(stats.totalInputTokens)}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<!-- Generated: sum of output tokens across all metrics -->
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{nf.format(stats.totalOutputTokens)}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<!-- Token stats: P50 / P95 / P99 as pre-formatted strings from stats.tokenStats -->
<td class="px-4 py-4 border-l border-gray-200 dark:border-white/10">
<div class="space-y-3">
<div class="grid grid-cols-3 gap-2 items-center">
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P50</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p50}
</div>
</div>
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P95</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p95}
</div>
</div>
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P99</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p99}
</div>
</div>
</div>
<!-- The histogram is rendered only when stats.histogramData is non-null -->
{#if stats.histogramData}
<TokenHistogram data={stats.histogramData} />
{/if}
</div>
</td>
</tr>
</tbody>
</table>
</div>
</div>
+41 -25
View File
@@ -1,23 +1,19 @@
<script lang="ts"> <script lang="ts">
interface HistogramData { import type { HistogramData } from "../lib/types";
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
interface Props { let {
data,
unit = "tokens/sec",
colorClass = "text-blue-500 dark:text-blue-400",
}: {
data: HistogramData; data: HistogramData;
} unit?: string;
colorClass?: string;
} = $props();
let { data }: Props = $props(); const height = 250;
const padding = { top: 30, right: 20, bottom: 40, left: 75 };
const height = 120; const viewBoxWidth = 1200;
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
const viewBoxWidth = 600;
const chartWidth = viewBoxWidth - padding.left - padding.right; const chartWidth = viewBoxWidth - padding.left - padding.right;
const chartHeight = height - padding.top - padding.bottom; const chartHeight = height - padding.top - padding.bottom;
@@ -43,6 +39,24 @@
opacity="0.3" opacity="0.3"
/> />
<!-- Y-axis ticks and labels -->
{#each [0, 0.5, 1] as fraction}
{@const tickCount = Math.round(maxCount * fraction)}
{@const tickY = height - padding.bottom - fraction * chartHeight}
<line
x1={padding.left - 8}
y1={tickY}
x2={padding.left}
y2={tickY}
stroke="currentColor"
stroke-width="1"
opacity="0.4"
/>
<text x={padding.left - 10} y={tickY + 10} font-size="26" fill="currentColor" opacity="0.8" text-anchor="end">
{tickCount}
</text>
{/each}
<!-- X-axis --> <!-- X-axis -->
<line <line
x1={padding.left} x1={padding.left}
@@ -69,9 +83,9 @@
height={barHeight} height={barHeight}
fill="currentColor" fill="currentColor"
opacity="0.6" opacity="0.6"
class="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer" class="{colorClass} hover:opacity-90 transition-opacity cursor-pointer"
/> />
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title> <title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} ${unit}\nCount: ${count}`}</title>
</g> </g>
{/each} {/each}
@@ -113,17 +127,19 @@
/> />
<!-- X-axis labels --> <!-- X-axis labels -->
<text x={padding.left} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="start"> <text x={padding.left} y={height - 8} font-size="26" fill="currentColor" opacity="0.8" text-anchor="start">
{data.min.toFixed(1)} {data.min.toFixed(1)}
</text> </text>
<text x={viewBoxWidth - padding.right} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="end"> <text
x={viewBoxWidth - padding.right}
y={height - 8}
font-size="26"
fill="currentColor"
opacity="0.8"
text-anchor="end"
>
{data.max.toFixed(1)} {data.max.toFixed(1)}
</text> </text>
<!-- X-axis label -->
<text x={padding.left + chartWidth / 2} y={height - 2} font-size="10" fill="currentColor" opacity="0.6" text-anchor="middle">
Tokens/Second Distribution
</text>
</svg> </svg>
</div> </div>
+1 -1
View File
@@ -25,7 +25,7 @@ function parseSSELine(line: string): StreamChunk | null {
const parsed = JSON.parse(data); const parsed = JSON.parse(data);
const delta = parsed.choices?.[0]?.delta; const delta = parsed.choices?.[0]?.delta;
const content = delta?.content || ""; const content = delta?.content || "";
const reasoning_content = delta?.reasoning_content || ""; const reasoning_content = delta?.reasoning_content || delta?.reasoning || "";
if (content || reasoning_content) { if (content || reasoning_content) {
return { content, reasoning_content, done: false }; return { content, reasoning_content, done: false };
+167
View File
@@ -0,0 +1,167 @@
import { describe, it, expect } from "vitest";
import { calculateHistogramData } from "./histogram";
describe("calculateHistogramData", () => {
describe("edge cases", () => {
  // Degenerate inputs: nothing to bin, or no spread between min and max.

  it("returns null for empty input", () => {
    const histogram = calculateHistogramData([]);
    expect(histogram).toBeNull();
  });

  it("handles single value", () => {
    const histogram = calculateHistogramData([42]);
    expect(histogram).not.toBeNull();

    // A lone sample collapses to one bin of width zero; every percentile is that sample.
    const { bins, min, max, binSize, p50, p95, p99 } = histogram!;
    expect(bins).toEqual([1]);
    expect(min).toBe(42);
    expect(max).toBe(42);
    expect(binSize).toBe(0);
    expect(p50).toBe(42);
    expect(p95).toBe(42);
    expect(p99).toBe(42);
  });

  it("handles all identical values", () => {
    const histogram = calculateHistogramData([10, 10, 10, 10, 10]);
    expect(histogram).not.toBeNull();

    const { bins, min, max, binSize } = histogram!;
    expect(bins).toEqual([5]);
    expect(min).toBe(10);
    expect(max).toBe(10);
    expect(binSize).toBe(0);
  });

  it("handles two distinct values", () => {
    const histogram = calculateHistogramData([10, 20]);
    expect(histogram).not.toBeNull();

    const { bins, min, max, p50 } = histogram!;
    expect(min).toBe(10);
    expect(max).toBe(20);
    expect(p50).toBe(15);

    // Every sample must land in exactly one bin.
    let counted = 0;
    for (const binCount of bins) counted += binCount;
    expect(counted).toBe(2);
  });
});
describe("bin distribution", () => {
  // Total number of samples counted across all bins.
  const totalCount = (bins: number[]): number => bins.reduce((s, b) => s + b, 0);

  it("bins sum to total number of values", () => {
    const values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    const result = calculateHistogramData(values);
    expect(result).not.toBeNull();
    expect(totalCount(result!.bins)).toBe(values.length);
  });

  it("distributes uniform values across bins", () => {
    const values = Array.from({ length: 100 }, (_, i) => i);
    const result = calculateHistogramData(values);
    expect(result).not.toBeNull();
    expect(result!.bins.length).toBe(8);
    expect(totalCount(result!.bins)).toBe(100);
  });

  it("places values in correct bins", () => {
    const values = [1, 1, 1, 5, 5, 9, 9, 9];
    const result = calculateHistogramData(values, { minBins: 3, maxBins: 3 });
    expect(result).not.toBeNull();
    expect(result!.bins.length).toBe(3);
    expect(totalCount(result!.bins)).toBe(8);
    // Range 1..9 split in three: the 1s fall in the first bin, the 5s in the
    // middle one, and the 9s (the maximum) are clamped into the last bin.
    expect(result!.bins).toEqual([3, 2, 3]);
  });

  it("handles skewed distribution", () => {
    const values = [1, 1, 1, 1, 1, 100];
    const result = calculateHistogramData(values);
    expect(result).not.toBeNull();
    expect(totalCount(result!.bins)).toBe(6);
    // n=6 gives 4 Sturges bins, raised to the default minimum of 5; the
    // outlier sits alone in the last bin and the middle bins stay empty.
    expect(result!.bins).toEqual([5, 0, 0, 0, 1]);
  });
});
describe("percentiles", () => {
  it("calculates correct p50 for even-length array", () => {
    const values = [1, 2, 3, 4];
    const result = calculateHistogramData(values);
    expect(result).not.toBeNull();
    expect(result!.p50).toBe(2.5);
  });

  it("calculates correct p50 for odd-length array", () => {
    const values = [1, 2, 3, 4, 5];
    const result = calculateHistogramData(values);
    expect(result).not.toBeNull();
    expect(result!.p50).toBe(3);
  });

  it("calculates p99 with interpolation", () => {
    const values = Array.from({ length: 100 }, (_, i) => i + 1);
    const result = calculateHistogramData(values);
    expect(result).not.toBeNull();
    expect(result!.p99).toBeCloseTo(99.01);
  });

  it("calculates p95 with interpolation", () => {
    const values = Array.from({ length: 100 }, (_, i) => i + 1);
    const result = calculateHistogramData(values);
    expect(result).not.toBeNull();
    expect(result!.p95).toBeCloseTo(95.05);
  });

  it("percentiles are monotonically increasing", () => {
    // Seeded linear congruential generator instead of Math.random(), so a
    // failure is reproducible. The products stay below 2^53, so the
    // arithmetic is exact.
    let seed = 12345;
    const nextValue = (): number => {
      seed = (seed * 1664525 + 1013904223) % 4294967296;
      return (seed / 4294967296) * 100;
    };
    const values = Array.from({ length: 200 }, nextValue);
    const result = calculateHistogramData(values);
    expect(result).not.toBeNull();
    expect(result!.p50).toBeLessThanOrEqual(result!.p95);
    expect(result!.p95).toBeLessThanOrEqual(result!.p99);
  });
});
describe("bin count adaptation", () => {
  // Builds the sequence 0, 1, ..., size - 1.
  const sequence = (size: number): number[] => Array.from({ length: size }, (_, i) => i);

  it("uses minimum bins for small datasets", () => {
    // 8 samples: Sturges gives 4, which is raised to the default minBins of 5
    const histogram = calculateHistogramData(sequence(8));
    expect(histogram!.bins.length).toBe(5);
  });

  it("scales bins with dataset size", () => {
    // 100 samples: Sturges gives 8
    const histogram = calculateHistogramData(sequence(100));
    expect(histogram!.bins.length).toBe(8);
  });

  it("caps bins at maximum", () => {
    // 1000 samples: Sturges gives 11, which is lowered to maxBins of 10
    const limits = { minBins: 5, maxBins: 10 };
    const histogram = calculateHistogramData(sequence(1000), limits);
    expect(histogram!.bins.length).toBe(10);
  });

  it("respects custom options", () => {
    // 100 samples: Sturges gives 8, already inside [5, 10]
    const limits = { minBins: 5, maxBins: 10 };
    const histogram = calculateHistogramData(sequence(100), limits);
    expect(histogram!.bins.length).toBe(8);
  });
});
describe("min and max", () => {
  // The reported range must match the extremes of the input, whatever its order or sign.

  it("correctly identifies min and max", () => {
    const { min, max } = calculateHistogramData([5, 3, 8, 1, 9, 2])!;
    expect(min).toBe(1);
    expect(max).toBe(9);
  });

  it("handles negative values", () => {
    const { min, max } = calculateHistogramData([-10, -5, 0, 5, 10])!;
    expect(min).toBe(-10);
    expect(max).toBe(10);
  });

  it("handles floating point values", () => {
    const { min, max } = calculateHistogramData([1.5, 2.7, 3.14, 0.5, 4.99])!;
    expect(min).toBe(0.5);
    expect(max).toBe(4.99);
  });
});
});
+71
View File
@@ -0,0 +1,71 @@
import type { HistogramData } from "./types";
/** Optional limits on how many bins a histogram may have. */
export interface HistogramOptions {
  /** Smallest bin count to use; defaults to 5 when omitted. */
  minBins?: number;
  /** Largest bin count to use; defaults to 20 when omitted. */
  maxBins?: number;
}

/** Bin-count limits applied for any option the caller leaves out. */
const DEFAULT_OPTIONS: Required<HistogramOptions> = {
  minBins: 5,
  maxBins: 20,
};

/**
 * Linearly interpolated percentile of an ascending-sorted array.
 *
 * @param sorted values in ascending order
 * @param p percentile in the range 0..100
 * @returns 0 for an empty array, the only element for a single-element array,
 *          otherwise the value interpolated between the two nearest ranks
 */
function percentile(sorted: number[], p: number): number {
  if (sorted.length === 0) return 0;
  if (sorted.length === 1) return sorted[0];

  // Fractional position of the percentile within the index range [0, n - 1].
  const rank = (p / 100) * (sorted.length - 1);
  const lower = Math.floor(rank);
  const upper = Math.ceil(rank);
  const fraction = rank - lower;

  return sorted[lower] + fraction * (sorted[upper] - sorted[lower]);
}

/** Turns a caller-supplied bin limit into a whole number of at least 1. */
function normalizeBinLimit(limit: number | undefined, fallback: number): number {
  if (limit === undefined || Number.isNaN(limit)) return fallback;
  return Math.max(1, Math.floor(limit));
}

/**
 * Builds histogram bins and P50/P95/P99 percentiles for a list of numbers.
 *
 * Non-finite entries (NaN, +/-Infinity) are ignored: they cannot be ordered or
 * assigned to a bin, and would otherwise corrupt min/max and the bin indices.
 *
 * The bin count follows Sturges' rule (ceil(log2(n)) + 1) clamped to
 * [minBins, maxBins]; when the two limits conflict, maxBins wins. Limits are
 * floored to whole numbers and raised to at least 1.
 *
 * @param values samples to summarise; the array is not modified
 * @param options optional bin-count limits
 * @returns null when there are no finite samples; a single bin with
 *          binSize 0 when all samples are equal; otherwise the binned
 *          distribution together with min, max and percentiles
 */
export function calculateHistogramData(
  values: number[],
  options: HistogramOptions = DEFAULT_OPTIONS,
): HistogramData | null {
  // filter() returns a new array, so sorting it leaves the caller's array untouched.
  const sorted = values.filter((value) => Number.isFinite(value)).sort((a, b) => a - b);
  if (sorted.length === 0) return null;

  const min = sorted[0];
  const max = sorted[sorted.length - 1];

  const p50 = percentile(sorted, 50);
  const p95 = percentile(sorted, 95);
  const p99 = percentile(sorted, 99);

  // No spread: a bin width cannot be computed, so report one bin holding everything.
  if (min === max) {
    return {
      bins: [sorted.length],
      min,
      max,
      binSize: 0,
      p50,
      p95,
      p99,
    };
  }

  const minBins = normalizeBinLimit(options.minBins, DEFAULT_OPTIONS.minBins);
  const maxBins = normalizeBinLimit(options.maxBins, DEFAULT_OPTIONS.maxBins);

  const sturges = Math.ceil(Math.log2(sorted.length)) + 1;
  const binCount = Math.min(maxBins, Math.max(minBins, sturges));
  const binSize = (max - min) / binCount;

  const bins: number[] = new Array(binCount).fill(0);
  for (const value of sorted) {
    // The maximum maps to index binCount; clamp it into the last bin.
    const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
    bins[binIndex]++;
  }

  return {
    bins,
    min,
    max,
    binSize,
    p50,
    p95,
    p99,
  };
}
+21 -4
View File
@@ -12,15 +12,22 @@ export interface Model {
aliases?: string[]; aliases?: string[];
} }
export interface Metrics { export interface TokenMetrics {
id: number;
timestamp: string;
model: string;
cache_tokens: number; cache_tokens: number;
input_tokens: number; input_tokens: number;
output_tokens: number; output_tokens: number;
prompt_per_second: number; prompt_per_second: number;
tokens_per_second: number; tokens_per_second: number;
}
export interface ActivityLogEntry {
id: number;
timestamp: string;
model: string;
req_path: string;
resp_content_type: string;
resp_status_code: number;
tokens: TokenMetrics;
duration_ms: number; duration_ms: number;
has_capture: boolean; has_capture: boolean;
} }
@@ -48,6 +55,16 @@ export interface APIEventEnvelope {
data: string; data: string;
} }
export interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
export interface VersionInfo { export interface VersionInfo {
build_date: string; build_date: string;
commit: string; commit: string;
+215 -39
View File
@@ -1,9 +1,89 @@
<script lang="ts"> <script lang="ts">
import { metrics, getCapture } from "../stores/api"; import { metrics, getCapture } from "../stores/api";
import ActivityStats from "../components/ActivityStats.svelte";
import Tooltip from "../components/Tooltip.svelte"; import Tooltip from "../components/Tooltip.svelte";
import CaptureDialog from "../components/CaptureDialog.svelte"; import CaptureDialog from "../components/CaptureDialog.svelte";
import { persistentStore } from "../stores/persistent";
import { onMount } from "svelte";
import type { ReqRespCapture } from "../lib/types"; import type { ReqRespCapture } from "../lib/types";
// Identifier of every column the activity table can display.
type ColumnKey =
| "id"
| "time"
| "model"
| "req_path"
| "resp_status_code"
| "resp_content_type"
| "cached"
| "prompt"
| "generated"
| "prompt_speed"
| "gen_speed"
| "duration"
| "capture";
// One entry of the column picker: its key, the label shown next to the
// checkbox, and whether it belongs to the default set of visible columns.
interface ColumnDef {
key: ColumnKey;
label: string;
defaultVisible: boolean;
}
// All available columns, in the order the picker lists them.
const columns: ColumnDef[] = [
{ key: "id", label: "ID", defaultVisible: true },
{ key: "time", label: "Time", defaultVisible: true },
{ key: "model", label: "Model", defaultVisible: true },
{ key: "req_path", label: "Path", defaultVisible: false },
{ key: "resp_status_code", label: "Status", defaultVisible: false },
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
{ key: "cached", label: "Cached", defaultVisible: true },
{ key: "prompt", label: "Prompt", defaultVisible: true },
{ key: "generated", label: "Generated", defaultVisible: true },
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
{ key: "duration", label: "Duration", defaultVisible: true },
{ key: "capture", label: "Capture", defaultVisible: true },
];
// Keys of the columns flagged defaultVisible; used as the store's initial value.
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
// Currently visible column keys, stored under "activity-columns".
// NOTE(review): presumably persisted across reloads by persistentStore - confirm in stores/persistent.
const visibleColumns = persistentStore<ColumnKey[]>(
"activity-columns",
defaultVisibleKeys
);
// Whether the column picker dropdown is open.
let columnsMenuOpen = $state(false);
// Element wrapping the picker; the outside-click handler tests clicks against it.
let dropdownContainer: HTMLDivElement | null = null;
// While mounted, close the column picker on Escape or on a click outside of it.
onMount(() => {
  const closeOnEscape = (event: KeyboardEvent) => {
    if (!columnsMenuOpen || event.key !== "Escape") return;
    columnsMenuOpen = false;
  };

  const closeOnOutsideClick = (event: MouseEvent) => {
    if (!columnsMenuOpen || !dropdownContainer) return;
    // Clicks inside the picker (including its toggle button) are left alone.
    if (dropdownContainer.contains(event.target as Node)) return;
    columnsMenuOpen = false;
  };

  document.addEventListener("keydown", closeOnEscape);
  document.addEventListener("click", closeOnOutsideClick);

  // Teardown: detach both listeners.
  return () => {
    document.removeEventListener("keydown", closeOnEscape);
    document.removeEventListener("click", closeOnOutsideClick);
  };
});
/**
 * Shows the column if it is hidden, hides it if it is shown.
 * The last remaining visible column cannot be hidden.
 */
function toggleColumn(key: ColumnKey) {
  const shown = $visibleColumns;

  if (!shown.includes(key)) {
    visibleColumns.set([...shown, key]);
    return;
  }

  // Keep at least one column on screen.
  if (shown.length <= 1) return;

  visibleColumns.set(shown.filter((k) => k !== key));
}
function formatSpeed(speed: number): string { function formatSpeed(speed: number): string {
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s"; return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
} }
@@ -62,64 +142,160 @@
</script> </script>
<div class="p-2"> <div class="p-2">
<h1 class="text-2xl font-bold">Activity</h1> <div class="mt-4 mb-4">
<ActivityStats />
</div>
{#if $metrics.length === 0} <div class="card overflow-auto relative min-h-[30rem]">
<div class="text-center py-8"> <div class="flex justify-end px-4" bind:this={dropdownContainer}>
<p class="text-gray-600">No metrics data available</p> <div class="relative">
<button
class="w-8 h-8 flex items-center justify-center rounded hover:bg-secondary-hover transition-colors"
onclick={() => (columnsMenuOpen = !columnsMenuOpen)}
title="Select columns"
>
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6V4m0 2a2 2 0 100 4m0-4a2 2 0 110 4m-6 8a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4m6 6v10m6-2a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4"></path>
</svg>
</button>
{#if columnsMenuOpen}
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]">
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10">
Columns
</div>
{#each columns as col (col.key)}
<label
class="flex items-center gap-2 px-3 py-1.5 text-sm cursor-pointer hover:bg-secondary-hover transition-colors"
>
<input
type="checkbox"
checked={$visibleColumns.includes(col.key)}
onchange={() => toggleColumn(col.key)}
class="rounded"
/>
{col.label}
</label>
{/each}
</div>
{/if}
</div>
</div> </div>
{:else}
<div class="card overflow-auto"> <table class="min-w-full divide-y">
<table class="min-w-full divide-y"> <thead class="border-gray-200 dark:border-white/10">
<thead class="border-gray-200 dark:border-white/10"> <tr class="text-left text-xs uppercase tracking-wider">
<tr class="text-left text-xs uppercase tracking-wider"> {#if $visibleColumns.includes("id")}
<th class="px-6 py-3">ID</th> <th class="px-6 py-3">ID</th>
{/if}
{#if $visibleColumns.includes("time")}
<th class="px-6 py-3">Time</th> <th class="px-6 py-3">Time</th>
{/if}
{#if $visibleColumns.includes("model")}
<th class="px-6 py-3">Model</th> <th class="px-6 py-3">Model</th>
{/if}
{#if $visibleColumns.includes("req_path")}
<th class="px-6 py-3">Path</th>
{/if}
{#if $visibleColumns.includes("resp_status_code")}
<th class="px-6 py-3">Status</th>
{/if}
{#if $visibleColumns.includes("resp_content_type")}
<th class="px-6 py-3">Content-Type</th>
{/if}
{#if $visibleColumns.includes("cached")}
<th class="px-6 py-3"> <th class="px-6 py-3">
Cached <Tooltip content="prompt tokens from cache" /> Cached <Tooltip content="prompt tokens from cache" />
</th> </th>
{/if}
{#if $visibleColumns.includes("prompt")}
<th class="px-6 py-3"> <th class="px-6 py-3">
Prompt <Tooltip content="new prompt tokens processed" /> Prompt <Tooltip content="new prompt tokens processed" />
</th> </th>
{/if}
{#if $visibleColumns.includes("generated")}
<th class="px-6 py-3">Generated</th> <th class="px-6 py-3">Generated</th>
<th class="px-6 py-3">Prompt Processing</th> {/if}
<th class="px-6 py-3">Generation Speed</th> {#if $visibleColumns.includes("prompt_speed")}
<th class="px-6 py-3">Prompt Speed</th>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<th class="px-6 py-3">Gen Speed</th>
{/if}
{#if $visibleColumns.includes("duration")}
<th class="px-6 py-3">Duration</th> <th class="px-6 py-3">Duration</th>
{/if}
{#if $visibleColumns.includes("capture")}
<th class="px-6 py-3">Capture</th> <th class="px-6 py-3">Capture</th>
{/if}
</tr>
</thead>
<tbody class="divide-y">
{#if sortedMetrics.length === 0}
<tr>
<td colspan={$visibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
No activity recorded
</td>
</tr> </tr>
</thead> {:else}
<tbody class="divide-y">
{#each sortedMetrics as metric (metric.id)} {#each sortedMetrics as metric (metric.id)}
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10"> <tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
<td class="px-4 py-4">{metric.id + 1}</td> {#if $visibleColumns.includes("id")}
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td> <td class="px-4 py-4">{metric.id + 1}</td>
<td class="px-6 py-4">{metric.model}</td> {/if}
<td class="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td> {#if $visibleColumns.includes("time")}
<td class="px-6 py-4">{metric.input_tokens.toLocaleString()}</td> <td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
<td class="px-6 py-4">{metric.output_tokens.toLocaleString()}</td> {/if}
<td class="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td> {#if $visibleColumns.includes("model")}
<td class="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td> <td class="px-6 py-4">{metric.model}</td>
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td> {/if}
<td class="px-6 py-4"> {#if $visibleColumns.includes("req_path")}
{#if metric.has_capture} <td class="px-6 py-4">{metric.req_path || "-"}</td>
<button {/if}
onclick={() => viewCapture(metric.id)} {#if $visibleColumns.includes("resp_status_code")}
disabled={loadingCaptureId === metric.id} <td class="px-6 py-4">{metric.resp_status_code || "-"}</td>
class="btn btn--sm" {/if}
> {#if $visibleColumns.includes("resp_content_type")}
{loadingCaptureId === metric.id ? "..." : "View"} <td class="px-6 py-4">{metric.resp_content_type || "-"}</td>
</button> {/if}
{:else} {#if $visibleColumns.includes("cached")}
<span class="text-txtsecondary">-</span> <td class="px-6 py-4">{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}</td>
{/if} {/if}
</td> {#if $visibleColumns.includes("prompt")}
<td class="px-6 py-4">{metric.tokens.input_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("generated")}
<td class="px-6 py-4">{metric.tokens.output_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("prompt_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.prompt_per_second)}</td>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.tokens_per_second)}</td>
{/if}
{#if $visibleColumns.includes("duration")}
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
{/if}
{#if $visibleColumns.includes("capture")}
<td class="px-6 py-4">
{#if metric.has_capture}
<button
onclick={() => viewCapture(metric.id)}
disabled={loadingCaptureId === metric.id}
class="btn btn--sm"
>
{loadingCaptureId === metric.id ? "..." : "View"}
</button>
{:else}
<span class="text-txtsecondary">-</span>
{/if}
</td>
{/if}
</tr> </tr>
{/each} {/each}
</tbody> {/if}
</table> </tbody>
</div> </table>
{/if} </div>
</div> </div>
<CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} /> <CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} />
+23 -37
View File
@@ -10,49 +10,35 @@
const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels"); const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels");
let direction = $derived<"horizontal" | "vertical">( let direction = $derived<"horizontal" | "vertical">(
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal" $screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal",
); );
/** Advances the stored view mode: panels -> proxy -> upstream -> panels. */
function cycleViewMode(): void {
  const order: ViewMode[] = ["panels", "proxy", "upstream"];
  // indexOf yields -1 for an unrecognised mode, so the cycle restarts at "panels".
  const upcoming = order[(order.indexOf($viewModeStore) + 1) % order.length];
  viewModeStore.set(upcoming);
}
/** Returns the short glyph shown on the view-mode button for the given mode. */
function getViewModeIcon(mode: ViewMode): string {
  const icons: Record<ViewMode, string> = {
    proxy: "P",
    upstream: "U",
    panels: "⊞",
  };
  return icons[mode];
}
/** Returns the human-readable name of the given view mode. */
function getViewModeLabel(mode: ViewMode): string {
  const labels: Record<ViewMode, string> = {
    proxy: "Proxy",
    upstream: "Upstream",
    panels: "Panels",
  };
  return labels[mode];
}
</script> </script>
<div class="flex flex-col h-full w-full gap-2"> <div class="flex flex-col h-full w-full gap-2">
<div class="flex items-center gap-2"> <div class="flex items-center gap-1">
<button <button
onclick={cycleViewMode} onclick={() => viewModeStore.set("panels")}
class="btn flex items-center gap-2 text-sm" class:btn={true}
title="Toggle view mode" class:bg-primary={$viewModeStore === "panels"}
aria-label="Toggle view mode: {getViewModeLabel($viewModeStore)}" class:text-btn-primary-text={$viewModeStore === "panels"}
> >
<span class="font-mono font-bold">{getViewModeIcon($viewModeStore)}</span> Both
<span>{getViewModeLabel($viewModeStore)}</span> </button>
<button
onclick={() => viewModeStore.set("proxy")}
class:btn={true}
class:bg-primary={$viewModeStore === "proxy"}
class:text-btn-primary-text={$viewModeStore === "proxy"}
>
Proxy
</button>
<button
onclick={() => viewModeStore.set("upstream")}
class:btn={true}
class:bg-primary={$viewModeStore === "upstream"}
class:text-btn-primary-text={$viewModeStore === "upstream"}
>
Upstream
</button> </button>
</div> </div>
+1 -9
View File
@@ -2,7 +2,6 @@
import { isNarrow } from "../stores/theme"; import { isNarrow } from "../stores/theme";
import { upstreamLogs } from "../stores/api"; import { upstreamLogs } from "../stores/api";
import ModelsPanel from "../components/ModelsPanel.svelte"; import ModelsPanel from "../components/ModelsPanel.svelte";
import StatsPanel from "../components/StatsPanel.svelte";
import LogPanel from "../components/LogPanel.svelte"; import LogPanel from "../components/LogPanel.svelte";
import ResizablePanels from "../components/ResizablePanels.svelte"; import ResizablePanels from "../components/ResizablePanels.svelte";
@@ -14,13 +13,6 @@
<ModelsPanel /> <ModelsPanel />
{/snippet} {/snippet}
{#snippet rightPanel()} {#snippet rightPanel()}
<div class="flex flex-col h-full space-y-4"> <LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
{#if direction === "horizontal"}
<StatsPanel />
{/if}
<div class="flex-1 min-h-0">
<LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
</div>
</div>
{/snippet} {/snippet}
</ResizablePanels> </ResizablePanels>
+12 -4
View File
@@ -1,5 +1,13 @@
import { writable } from "svelte/store"; import { writable } from "svelte/store";
import type { Model, Metrics, VersionInfo, LogData, APIEventEnvelope, ReqRespCapture, InFlightStats } from "../lib/types"; import type {
Model,
ActivityLogEntry,
VersionInfo,
LogData,
APIEventEnvelope,
ReqRespCapture,
InFlightStats,
} from "../lib/types";
import { connectionState } from "./theme"; import { connectionState } from "./theme";
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
@@ -8,7 +16,7 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
export const models = writable<Model[]>([]); export const models = writable<Model[]>([]);
export const proxyLogs = writable<string>(""); export const proxyLogs = writable<string>("");
export const upstreamLogs = writable<string>(""); export const upstreamLogs = writable<string>("");
export const metrics = writable<Metrics[]>([]); export const metrics = writable<ActivityLogEntry[]>([]);
export const inFlightRequests = writable<number>(0); export const inFlightRequests = writable<number>(0);
export const versionInfo = writable<VersionInfo>({ export const versionInfo = writable<VersionInfo>({
build_date: "unknown", build_date: "unknown",
@@ -62,7 +70,7 @@ export function enableAPIEvents(enabled: boolean): void {
const newModels = JSON.parse(message.data) as Model[]; const newModels = JSON.parse(message.data) as Model[];
// Sort models by name and id // Sort models by name and id
newModels.sort((a, b) => { newModels.sort((a, b) => {
return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric : true} ); return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric: true });
}); });
models.set(newModels); models.set(newModels);
break; break;
@@ -82,7 +90,7 @@ export function enableAPIEvents(enabled: boolean): void {
} }
case "metrics": { case "metrics": {
const newMetrics = JSON.parse(message.data) as Metrics[]; const newMetrics = JSON.parse(message.data) as ActivityLogEntry[];
metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]); metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]);
break; break;
} }
+2 -1
View File
@@ -3,7 +3,8 @@ import { persistentStore } from "./persistent";
import type { ScreenWidth } from "../lib/types"; import type { ScreenWidth } from "../lib/types";
// Persistent stores // Persistent stores
export const isDarkMode = persistentStore<boolean>("theme", false); const systemDark = typeof window !== "undefined" && window.matchMedia("(prefers-color-scheme: dark)").matches;
export const isDarkMode = persistentStore<boolean>("theme", systemDark);
export const appTitle = persistentStore<string>("app-title", "llama-swap"); export const appTitle = persistentStore<string>("app-title", "llama-swap");
// Non-persistent stores // Non-persistent stores
+4
View File
@@ -26,6 +26,10 @@ export default defineConfig({
assetsDir: "assets", assetsDir: "assets",
}, },
server: { server: {
// yes very insecure but who's running this thing
// on the public internet for dev?! haha.
host: "0.0.0.0",
allowedHosts: true,
proxy: { proxy: {
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development "/api": "http://localhost:8080", // Proxy API calls to Go backend during development
"/logs": "http://localhost:8080", "/logs": "http://localhost:8080",