Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c79114d40a | |||
| 430166d5eb | |||
| 5b4beaceef | |||
| fd3c28ffc5 | |||
| a846c4f18c | |||
| 5bae33a769 | |||
| 8f4ff01f93 |
@@ -20,6 +20,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
|||||||
- `v1/chat/completions`
|
- `v1/chat/completions`
|
||||||
- `v1/responses`
|
- `v1/responses`
|
||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
|
- `v1/models` - list available models
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||||
- `v1/audio/voices`
|
- `v1/audio/voices`
|
||||||
@@ -39,9 +40,17 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
|||||||
- ✅ llama-swap API
|
- ✅ llama-swap API
|
||||||
- `/ui` - web UI
|
- `/ui` - web UI
|
||||||
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
|
||||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||||
- `/log` - remote log monitoring
|
- `POST /api/models/unload` - manually unload all running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||||
|
- `POST /api/models/unload/:model_id` - unload a specific model
|
||||||
|
- `/logs` - remote log monitoring
|
||||||
|
- `GET /logs` returns buffered plain text logs.
|
||||||
|
- If `Accept: text/html` is sent, `/logs` redirects to `/ui/`.
|
||||||
|
- `GET /logs/stream` keeps the connection open for live log streaming.
|
||||||
|
- Stream endpoints send buffered history first by default; add `?no-history` to stream only new lines.
|
||||||
|
- `GET /logs/stream/proxy` streams proxy logs only.
|
||||||
|
- `GET /logs/stream/upstream` streams upstream process logs only.
|
||||||
|
- `GET /logs/stream/{model_id}` streams logs for one model (including IDs with slashes, like `author/model`).
|
||||||
- `/health` - just returns "OK"
|
- `/health` - just returns "OK"
|
||||||
- ✅ API Key support - define keys to restrict access to API endpoints
|
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||||
- ✅ Customizable
|
- ✅ Customizable
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ go 1.26.1
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/billziss-gh/golib v0.2.0
|
github.com/billziss-gh/golib v0.2.0
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.1
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/klauspost/compress v1.18.5
|
github.com/klauspost/compress v1.18.5
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
@@ -36,6 +37,7 @@ require (
|
|||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.45.0 // indirect
|
golang.org/x/crypto v0.45.0 // indirect
|
||||||
golang.org/x/net v0.47.0 // indirect
|
golang.org/x/net v0.47.0 // indirect
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
@@ -77,6 +79,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
|
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||||
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
|
|||||||
Vendored
+102
@@ -0,0 +1,102 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrExceedsMaxSize = errors.New("item exceeds maximum cache size")
|
||||||
|
ErrNotFound = errors.New("item not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Cache struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
items map[int][]byte
|
||||||
|
order []int
|
||||||
|
size int
|
||||||
|
maxSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(maxBytes int) *Cache {
|
||||||
|
return &Cache{
|
||||||
|
items: make(map[int][]byte),
|
||||||
|
order: make([]int, 0),
|
||||||
|
maxSize: maxBytes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Add(id int, data []byte) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
dataSize := len(data)
|
||||||
|
if dataSize > c.maxSize {
|
||||||
|
return ErrExceedsMaxSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// If key already exists, remove old entry from size and order
|
||||||
|
if old, exists := c.items[id]; exists {
|
||||||
|
c.size -= len(old)
|
||||||
|
c.removeOrder(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict oldest (FIFO) until room available
|
||||||
|
for c.size+dataSize > c.maxSize && len(c.order) > 0 {
|
||||||
|
oldestID := c.order[0]
|
||||||
|
c.order = c.order[1:]
|
||||||
|
if evicted, exists := c.items[oldestID]; exists {
|
||||||
|
c.size -= len(evicted)
|
||||||
|
delete(c.items, oldestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.items[id] = data
|
||||||
|
c.order = append(c.order, id)
|
||||||
|
c.size += dataSize
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) removeOrder(id int) {
|
||||||
|
for i, v := range c.order {
|
||||||
|
if v == id {
|
||||||
|
c.order = append(c.order[:i], c.order[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Get(id int) ([]byte, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
data, exists := c.items[id]
|
||||||
|
if !exists {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Has(id int) bool {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
_, exists := c.items[id]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Size() int {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
return c.size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Clear() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.items = make(map[int][]byte)
|
||||||
|
c.order = c.order[:0]
|
||||||
|
c.size = 0
|
||||||
|
}
|
||||||
Vendored
+130
@@ -0,0 +1,130 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCache_Add(t *testing.T) {
|
||||||
|
t.Run("adds and retrieves item", func(t *testing.T) {
|
||||||
|
c := New(1024)
|
||||||
|
data := []byte("hello")
|
||||||
|
require.NoError(t, c.Add(1, data))
|
||||||
|
|
||||||
|
got, err := c.Get(1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, data, got)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error for oversized item", func(t *testing.T) {
|
||||||
|
c := New(10)
|
||||||
|
err := c.Add(1, make([]byte, 20))
|
||||||
|
assert.ErrorIs(t, err, ErrExceedsMaxSize)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("evicts oldest items to make room", func(t *testing.T) {
|
||||||
|
c := New(100)
|
||||||
|
|
||||||
|
require.NoError(t, c.Add(1, make([]byte, 40)))
|
||||||
|
require.NoError(t, c.Add(2, make([]byte, 40)))
|
||||||
|
// Adding item 3 should evict item 1
|
||||||
|
require.NoError(t, c.Add(3, make([]byte, 40)))
|
||||||
|
|
||||||
|
assert.False(t, c.Has(1))
|
||||||
|
assert.True(t, c.Has(2))
|
||||||
|
assert.True(t, c.Has(3))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwrites existing key", func(t *testing.T) {
|
||||||
|
c := New(100)
|
||||||
|
require.NoError(t, c.Add(1, []byte("old")))
|
||||||
|
require.NoError(t, c.Add(1, []byte("new")))
|
||||||
|
|
||||||
|
got, err := c.Get(1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("new"), got)
|
||||||
|
assert.Equal(t, 3, c.Size())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCache_Get(t *testing.T) {
|
||||||
|
t.Run("returns ErrNotFound for missing key", func(t *testing.T) {
|
||||||
|
c := New(100)
|
||||||
|
_, err := c.Get(99)
|
||||||
|
assert.ErrorIs(t, err, ErrNotFound)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCache_Has(t *testing.T) {
|
||||||
|
t.Run("returns true for existing key", func(t *testing.T) {
|
||||||
|
c := New(100)
|
||||||
|
require.NoError(t, c.Add(1, []byte("data")))
|
||||||
|
assert.True(t, c.Has(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns false for missing key", func(t *testing.T) {
|
||||||
|
c := New(100)
|
||||||
|
assert.False(t, c.Has(1))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCache_Size(t *testing.T) {
|
||||||
|
t.Run("tracks byte usage", func(t *testing.T) {
|
||||||
|
c := New(1000)
|
||||||
|
assert.Equal(t, 0, c.Size())
|
||||||
|
|
||||||
|
require.NoError(t, c.Add(1, make([]byte, 100)))
|
||||||
|
assert.Equal(t, 100, c.Size())
|
||||||
|
|
||||||
|
require.NoError(t, c.Add(2, make([]byte, 200)))
|
||||||
|
assert.Equal(t, 300, c.Size())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("updates on eviction", func(t *testing.T) {
|
||||||
|
c := New(150)
|
||||||
|
require.NoError(t, c.Add(1, make([]byte, 100)))
|
||||||
|
require.NoError(t, c.Add(2, make([]byte, 100)))
|
||||||
|
|
||||||
|
// Item 1 should be evicted, size = 100
|
||||||
|
assert.Equal(t, 100, c.Size())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCache_Clear(t *testing.T) {
|
||||||
|
t.Run("removes all items and resets size", func(t *testing.T) {
|
||||||
|
c := New(1000)
|
||||||
|
require.NoError(t, c.Add(1, []byte("a")))
|
||||||
|
require.NoError(t, c.Add(2, []byte("b")))
|
||||||
|
|
||||||
|
c.Clear()
|
||||||
|
|
||||||
|
assert.Equal(t, 0, c.Size())
|
||||||
|
assert.False(t, c.Has(1))
|
||||||
|
assert.False(t, c.Has(2))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCache_Concurrent(t *testing.T) {
|
||||||
|
t.Run("concurrent operations are safe", func(t *testing.T) {
|
||||||
|
c := New(10000)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
key := id*100 + j
|
||||||
|
_ = c.Add(key, []byte("data"))
|
||||||
|
_, _ = c.Get(key)
|
||||||
|
_ = c.Has(key)
|
||||||
|
_ = c.Size()
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -646,9 +646,6 @@ func validateMacro(name string, value any) error {
|
|||||||
// Validate that value is a scalar type
|
// Validate that value is a scalar type
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case string:
|
case string:
|
||||||
if len(v) >= 1024 {
|
|
||||||
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
|
||||||
}
|
|
||||||
// Check for self-reference
|
// Check for self-reference
|
||||||
macroSlug := fmt.Sprintf("${%s}", name)
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
if strings.Contains(v, macroSlug) {
|
if strings.Contains(v, macroSlug) {
|
||||||
|
|||||||
+1
-1
@@ -6,7 +6,7 @@ const ProcessStateChangeEventID = 0x01
|
|||||||
const ChatCompletionStatsEventID = 0x02
|
const ChatCompletionStatsEventID = 0x02
|
||||||
const ConfigFileChangedEventID = 0x03
|
const ConfigFileChangedEventID = 0x03
|
||||||
const LogDataEventID = 0x04
|
const LogDataEventID = 0x04
|
||||||
const TokenMetricsEventID = 0x05
|
const ActivityLogEventID = 0x05
|
||||||
const ModelPreloadedEventID = 0x06
|
const ModelPreloadedEventID = 0x06
|
||||||
const InFlightRequestsEventID = 0x07
|
const InFlightRequestsEventID = 0x07
|
||||||
|
|
||||||
|
|||||||
+182
-135
@@ -12,9 +12,11 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/cache"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,37 +44,53 @@ var zstdDecPool = &sync.Pool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
|
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
|
||||||
// Returns compressed bytes and the original JSON byte count for logging.
|
// Returns compressed bytes and the original CBOR byte count for logging.
|
||||||
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||||
jsonBytes, err := json.Marshal(c)
|
cborBytes, err := cbor.Marshal(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||||
}
|
}
|
||||||
enc := zstdEncPool.Get().(*zstd.Encoder)
|
zenc := zstdEncPool.Get().(*zstd.Encoder)
|
||||||
defer zstdEncPool.Put(enc)
|
defer zstdEncPool.Put(zenc)
|
||||||
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
|
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decompressCapture decompresses zstd-compressed JSON and returns it.
|
// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture.
|
||||||
func decompressCapture(data []byte) ([]byte, error) {
|
func decompressCapture(data []byte) (*ReqRespCapture, error) {
|
||||||
dec := zstdDecPool.Get().(*zstd.Decoder)
|
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||||
defer zstdDecPool.Put(dec)
|
defer zstdDecPool.Put(dec)
|
||||||
return dec.DecodeAll(data, nil)
|
cborBytes, err := dec.DecodeAll(data, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decompress capture: %w", err)
|
||||||
|
}
|
||||||
|
var capture ReqRespCapture
|
||||||
|
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal capture: %w", err)
|
||||||
|
}
|
||||||
|
return &capture, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
// TokenMetrics holds token usage and performance metrics
|
||||||
type TokenMetrics struct {
|
type TokenMetrics struct {
|
||||||
ID int `json:"id"`
|
CachedTokens int `json:"cache_tokens"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
InputTokens int `json:"input_tokens"`
|
||||||
Model string `json:"model"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
CachedTokens int `json:"cache_tokens"`
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||||
InputTokens int `json:"input_tokens"`
|
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
}
|
||||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
|
||||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
// ActivityLogEntry records one proxied request: model, request path, response status and content type, token metrics, duration, and whether a capture was stored
|
||||||
DurationMs int `json:"duration_ms"`
|
type ActivityLogEntry struct {
|
||||||
HasCapture bool `json:"has_capture"`
|
ID int `json:"id"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
ReqPath string `json:"req_path"`
|
||||||
|
RespContentType string `json:"resp_content_type"`
|
||||||
|
RespStatusCode int `json:"resp_status_code"`
|
||||||
|
Tokens TokenMetrics `json:"tokens"`
|
||||||
|
DurationMs int `json:"duration_ms"`
|
||||||
|
HasCapture bool `json:"has_capture"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReqRespCapture struct {
|
type ReqRespCapture struct {
|
||||||
@@ -84,48 +102,45 @@ type ReqRespCapture struct {
|
|||||||
RespBody []byte `json:"resp_body"`
|
RespBody []byte `json:"resp_body"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenMetricsEvent represents a token metrics event
|
// ActivityLogEvent is the event payload that carries a single ActivityLogEntry
|
||||||
type TokenMetricsEvent struct {
|
type ActivityLogEvent struct {
|
||||||
Metrics TokenMetrics
|
Metrics ActivityLogEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e TokenMetricsEvent) Type() uint32 {
|
func (e ActivityLogEvent) Type() uint32 {
|
||||||
return TokenMetricsEventID // defined in events.go
|
return ActivityLogEventID // defined in events.go
|
||||||
}
|
}
|
||||||
|
|
||||||
// metricsMonitor parses llama-server output for token statistics
|
// metricsMonitor parses llama-server output for token statistics
|
||||||
type metricsMonitor struct {
|
type metricsMonitor struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
metrics []TokenMetrics
|
metrics []ActivityLogEntry
|
||||||
maxMetrics int
|
maxMetrics int
|
||||||
nextID int
|
nextID int
|
||||||
logger *LogMonitor
|
logger *LogMonitor
|
||||||
|
|
||||||
// capture fields
|
// capture fields
|
||||||
enableCaptures bool
|
enableCaptures bool
|
||||||
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
|
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
|
||||||
captureOrder []int // track insertion order for FIFO eviction
|
|
||||||
captureSize int // current total compressed size in bytes
|
|
||||||
maxCaptureSize int // max bytes for captures (uncompressed)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||||
// capture buffer size in megabytes; 0 disables captures.
|
// capture buffer size in megabytes; 0 disables captures.
|
||||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||||
return &metricsMonitor{
|
mm := &metricsMonitor{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
maxMetrics: maxMetrics,
|
maxMetrics: maxMetrics,
|
||||||
enableCaptures: captureBufferMB > 0,
|
enableCaptures: captureBufferMB > 0,
|
||||||
captures: make(map[int][]byte),
|
|
||||||
captureOrder: make([]int, 0),
|
|
||||||
captureSize: 0,
|
|
||||||
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
|
||||||
}
|
}
|
||||||
|
if captureBufferMB > 0 {
|
||||||
|
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
|
||||||
|
}
|
||||||
|
return mm
|
||||||
}
|
}
|
||||||
|
|
||||||
// addMetrics adds a new metric to the collection and publishes an event.
|
// queueMetrics adds a new metric to the collection without emitting an event.
|
||||||
// Returns the assigned metric ID.
|
// Returns the assigned metric ID. Call emitMetric after capture setup.
|
||||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
|
||||||
mp.mu.Lock()
|
mp.mu.Lock()
|
||||||
defer mp.mu.Unlock()
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
@@ -135,93 +150,75 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
|||||||
if len(mp.metrics) > mp.maxMetrics {
|
if len(mp.metrics) > mp.maxMetrics {
|
||||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||||
}
|
}
|
||||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
|
||||||
return metric.ID
|
return metric.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
// addCapture adds a new capture to the buffer with size-based eviction.
|
// emitMetric publishes an ActivityLogEvent for the given metric.
|
||||||
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
|
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
|
||||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
event.Emit(ActivityLogEvent{Metrics: metric})
|
||||||
|
}
|
||||||
|
|
||||||
|
// addCapture compresses and stores a capture in the cache.
|
||||||
|
// Returns true if the capture was stored, false otherwise.
|
||||||
|
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
|
||||||
if !mp.enableCaptures {
|
if !mp.enableCaptures {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
compressed, uncompressedBytes, err := compressCapture(&capture)
|
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
captureSize := len(compressed)
|
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
|
||||||
if captureSize > mp.maxCaptureSize {
|
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
|
||||||
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
return false
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
|
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
|
||||||
|
|
||||||
mp.mu.Lock()
|
|
||||||
defer mp.mu.Unlock()
|
|
||||||
|
|
||||||
// Evict oldest (FIFO) until room available for the compressed data
|
|
||||||
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
|
||||||
oldestID := mp.captureOrder[0]
|
|
||||||
mp.captureOrder = mp.captureOrder[1:]
|
|
||||||
if evicted, exists := mp.captures[oldestID]; exists {
|
|
||||||
l := len(evicted)
|
|
||||||
mp.captureSize -= l
|
|
||||||
delete(mp.captures, oldestID)
|
|
||||||
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mp.captures[capture.ID] = compressed
|
|
||||||
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
|
||||||
mp.captureSize += captureSize
|
|
||||||
|
|
||||||
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
||||||
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
||||||
mp.mu.RLock()
|
if mp.captureCache == nil {
|
||||||
defer mp.mu.RUnlock()
|
return nil, false
|
||||||
|
}
|
||||||
data, exists := mp.captures[id]
|
data, err := mp.captureCache.Get(id)
|
||||||
return data, exists
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return data, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
|
// getCaptureByID decompresses and unmarshals a capture by ID.
|
||||||
// If decompress=false, returns the raw zstd-compressed bytes.
|
// Returns nil if the capture is not found or decompression fails.
|
||||||
// Returns nil if the capture is not found.
|
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||||
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
|
if mp.captureCache == nil {
|
||||||
mp.mu.RLock()
|
return nil
|
||||||
defer mp.mu.RUnlock()
|
}
|
||||||
|
data, exists := mp.getCompressedBytes(id)
|
||||||
data, exists := mp.captures[id]
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !decompress {
|
capture, err := decompressCapture(data)
|
||||||
return data
|
|
||||||
}
|
|
||||||
|
|
||||||
decompressed, err := decompressCapture(data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return decompressed
|
return capture
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMetrics returns a copy of the current metrics
|
// getMetrics returns a copy of the current metrics
|
||||||
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
|
||||||
mp.mu.RLock()
|
mp.mu.RLock()
|
||||||
defer mp.mu.RUnlock()
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
result := make([]TokenMetrics, len(mp.metrics))
|
result := make([]ActivityLogEntry, len(mp.metrics))
|
||||||
copy(result, mp.metrics)
|
copy(result, mp.metrics)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
@@ -230,22 +227,52 @@ func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
|||||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||||
mp.mu.RLock()
|
mp.mu.RLock()
|
||||||
defer mp.mu.RUnlock()
|
defer mp.mu.RUnlock()
|
||||||
return json.Marshal(mp.metrics)
|
|
||||||
|
if mp.captureCache == nil {
|
||||||
|
return json.Marshal(mp.metrics)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a copy with up-to-date has_capture from cache
|
||||||
|
result := make([]ActivityLogEntry, len(mp.metrics))
|
||||||
|
for i, m := range mp.metrics {
|
||||||
|
m.HasCapture = mp.captureCache.Has(m.ID)
|
||||||
|
result[i] = m
|
||||||
|
}
|
||||||
|
return json.Marshal(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapHandler wraps the proxy handler to extract token metrics
|
// Capture field flags for controlling what is saved in ReqRespCapture.
|
||||||
|
// captureFields is a bit set selecting which parts of a request/response
// pair are saved in a ReqRespCapture. Callers test individual bits with
// `fields&flag != 0`.
type captureFields uint

const (
	// captureNone is the empty set. It is 0 so that it equals the zero
	// value of captureFields and never matches any bit test. It was
	// previously `1 << iota` (== 1), which made "none" a set bit.
	captureNone captureFields = 0
	// iota is 1 on this line, so the flag values stay exactly as before:
	// 2, 4, 8, 16.
	captureReqHeaders captureFields = 1 << iota
	captureReqBody
	captureRespHeaders
	captureRespBody
)

// Convenience combinations of the individual flags.
const (
	captureReqAll  = captureReqHeaders | captureReqBody   // both request parts
	captureRespAll = captureRespHeaders | captureRespBody // both response parts
	captureAll     = captureReqAll | captureRespAll       // everything
)
|
||||||
|
|
||||||
|
// wrapHandler wraps the proxy handler to extract token metrics.
|
||||||
|
// captureFields controls what is saved in the ReqRespCapture using bitwise flags.
|
||||||
// if wrapHandler returns an error it is safe to assume that no
|
// if wrapHandler returns an error it is safe to assume that no
|
||||||
// data was sent to the client
|
// data was sent to the client
|
||||||
func (mp *metricsMonitor) wrapHandler(
|
func (mp *metricsMonitor) wrapHandler(
|
||||||
modelID string,
|
modelID string,
|
||||||
writer gin.ResponseWriter,
|
writer gin.ResponseWriter,
|
||||||
request *http.Request,
|
request *http.Request,
|
||||||
|
captureFields captureFields,
|
||||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||||
) error {
|
) error {
|
||||||
// Capture request body and headers if captures enabled
|
// Capture request body and headers if captures enabled
|
||||||
var reqBody []byte
|
var reqBody []byte
|
||||||
var reqHeaders map[string]string
|
var reqHeaders map[string]string
|
||||||
if mp.enableCaptures {
|
if mp.enableCaptures && (captureFields&captureReqBody) != 0 {
|
||||||
if request.Body != nil {
|
if request.Body != nil {
|
||||||
var err error
|
var err error
|
||||||
reqBody, err = io.ReadAll(request.Body)
|
reqBody, err = io.ReadAll(request.Body)
|
||||||
@@ -255,6 +282,8 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
request.Body.Close()
|
request.Body.Close()
|
||||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 {
|
||||||
reqHeaders = make(map[string]string)
|
reqHeaders = make(map[string]string)
|
||||||
for key, values := range request.Header {
|
for key, values := range request.Header {
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
@@ -278,22 +307,28 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
// after this point we have to assume that data was sent to the client
|
// after this point we have to assume that data was sent to the client
|
||||||
// and we can only log errors but not send them to clients
|
// and we can only log errors but not send them to clients
|
||||||
|
|
||||||
if recorder.Status() != http.StatusOK {
|
// Initialize default metrics - recorded for every request
|
||||||
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
|
tm := ActivityLogEntry{
|
||||||
return nil
|
Timestamp: time.Now(),
|
||||||
|
Model: modelID,
|
||||||
|
ReqPath: request.URL.Path,
|
||||||
|
RespContentType: recorder.Header().Get("Content-Type"),
|
||||||
|
RespStatusCode: recorder.Status(),
|
||||||
|
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize default metrics - these will always be recorded
|
if recorder.Status() != http.StatusOK {
|
||||||
tm := TokenMetrics{
|
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||||
Timestamp: time.Now(),
|
tm.ID = mp.queueMetrics(tm)
|
||||||
Model: modelID,
|
mp.emitMetric(tm)
|
||||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
body := recorder.body.Bytes()
|
body := recorder.body.Bytes()
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||||
mp.addMetrics(tm)
|
tm.ID = mp.queueMetrics(tm)
|
||||||
|
mp.emitMetric(tm)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,7 +338,8 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
body, err = decompressBody(body, encoding)
|
body, err = decompressBody(body, encoding)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
mp.addMetrics(tm)
|
tm.ID = mp.queueMetrics(tm)
|
||||||
|
mp.emitMetric(tm)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -311,7 +347,8 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
} else {
|
} else {
|
||||||
tm = parsed
|
tm.Tokens = parsed.Tokens
|
||||||
|
tm.DurationMs = parsed.DurationMs
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if gjson.ValidBytes(body) {
|
if gjson.ValidBytes(body) {
|
||||||
@@ -331,7 +368,8 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
} else {
|
} else {
|
||||||
tm = parsedMetrics
|
tm.Tokens = parsedMetrics.Tokens
|
||||||
|
tm.DurationMs = parsedMetrics.DurationMs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -342,39 +380,50 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
// Build capture if enabled and determine if it will be stored
|
// Build capture if enabled and determine if it will be stored
|
||||||
var capture *ReqRespCapture
|
var capture *ReqRespCapture
|
||||||
if mp.enableCaptures {
|
if mp.enableCaptures {
|
||||||
respHeaders := make(map[string]string)
|
var respHeaders map[string]string
|
||||||
for key, values := range recorder.Header() {
|
var respBody []byte
|
||||||
if len(values) > 0 {
|
if (captureFields & captureRespHeaders) != 0 {
|
||||||
respHeaders[key] = values[0]
|
respHeaders = make(map[string]string)
|
||||||
|
for key, values := range recorder.Header() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
respHeaders[key] = values[0]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
redactHeaders(respHeaders)
|
||||||
|
delete(respHeaders, "Content-Encoding")
|
||||||
|
}
|
||||||
|
if (captureFields & captureRespBody) != 0 {
|
||||||
|
respBody = body
|
||||||
}
|
}
|
||||||
redactHeaders(respHeaders)
|
|
||||||
delete(respHeaders, "Content-Encoding")
|
|
||||||
capture = &ReqRespCapture{
|
capture = &ReqRespCapture{
|
||||||
ReqPath: request.URL.Path,
|
ReqPath: request.URL.Path,
|
||||||
ReqHeaders: reqHeaders,
|
ReqHeaders: reqHeaders,
|
||||||
ReqBody: reqBody,
|
ReqBody: reqBody,
|
||||||
RespHeaders: respHeaders,
|
RespHeaders: respHeaders,
|
||||||
RespBody: body,
|
RespBody: respBody,
|
||||||
}
|
|
||||||
compressed, _, err := compressCapture(capture)
|
|
||||||
if err == nil && len(compressed) <= mp.maxCaptureSize {
|
|
||||||
tm.HasCapture = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
metricID := mp.addMetrics(tm)
|
metricID := mp.queueMetrics(tm)
|
||||||
|
tm.ID = metricID
|
||||||
|
|
||||||
// Store capture if enabled
|
// Store capture if enabled
|
||||||
if capture != nil {
|
if capture != nil {
|
||||||
capture.ID = metricID
|
capture.ID = metricID
|
||||||
mp.addCapture(*capture)
|
if mp.addCapture(*capture) {
|
||||||
|
tm.HasCapture = true
|
||||||
|
mp.mu.Lock()
|
||||||
|
mp.metrics[len(mp.metrics)-1].HasCapture = true
|
||||||
|
mp.mu.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mp.emitMetric(tm)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
|
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
|
||||||
// Iterate **backwards** through the body looking for the data payload with
|
// Iterate **backwards** through the body looking for the data payload with
|
||||||
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
||||||
|
|
||||||
@@ -428,10 +477,10 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
|
||||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||||
|
|
||||||
// default values
|
// default values
|
||||||
@@ -481,15 +530,17 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return TokenMetrics{
|
return ActivityLogEntry{
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
Model: modelID,
|
Model: modelID,
|
||||||
CachedTokens: cachedTokens,
|
Tokens: TokenMetrics{
|
||||||
InputTokens: inputTokens,
|
CachedTokens: cachedTokens,
|
||||||
OutputTokens: outputTokens,
|
InputTokens: inputTokens,
|
||||||
PromptPerSecond: promptPerSecond,
|
OutputTokens: outputTokens,
|
||||||
TokensPerSecond: tokensPerSecond,
|
PromptPerSecond: promptPerSecond,
|
||||||
DurationMs: durationMs,
|
TokensPerSecond: tokensPerSecond,
|
||||||
|
},
|
||||||
|
DurationMs: durationMs,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -527,15 +578,11 @@ func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
|||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
body: bodyBuffer,
|
body: bodyBuffer,
|
||||||
tee: io.MultiWriter(w, bodyBuffer),
|
tee: io.MultiWriter(w, bodyBuffer),
|
||||||
|
start: time.Now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||||
if w.start.IsZero() {
|
|
||||||
w.start = time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Single write operation that writes to both the response and buffer
|
|
||||||
return w.tee.Write(b)
|
return w.tee.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+344
-175
@@ -12,8 +12,10 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/cache"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
@@ -22,27 +24,29 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
|||||||
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
metric := TokenMetrics{
|
metric := ActivityLogEntry{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
InputTokens: 100,
|
Tokens: TokenMetrics{
|
||||||
OutputTokens: 50,
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mm.addMetrics(metric)
|
mm.queueMetrics(metric)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, 0, metrics[0].ID)
|
assert.Equal(t, 0, metrics[0].ID)
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("increments ID for each metric", func(t *testing.T) {
|
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
mm.addMetrics(TokenMetrics{Model: "model"})
|
mm.queueMetrics(ActivityLogEntry{Model: "model"})
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
@@ -57,9 +61,11 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
|||||||
|
|
||||||
// Add 5 metrics
|
// Add 5 metrics
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
mm.addMetrics(TokenMetrics{
|
mm.queueMetrics(ActivityLogEntry{
|
||||||
Model: "model",
|
Model: "model",
|
||||||
InputTokens: i,
|
Tokens: TokenMetrics{
|
||||||
|
InputTokens: i,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,29 +78,32 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
|||||||
assert.Equal(t, 4, metrics[2].ID)
|
assert.Equal(t, 4, metrics[2].ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
t.Run("emits ActivityLogEvent", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
receivedEvent := make(chan TokenMetricsEvent, 1)
|
receivedEvent := make(chan ActivityLogEvent, 1)
|
||||||
cancel := event.On(func(e TokenMetricsEvent) {
|
cancel := event.On(func(e ActivityLogEvent) {
|
||||||
receivedEvent <- e
|
receivedEvent <- e
|
||||||
})
|
})
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
metric := TokenMetrics{
|
metric := ActivityLogEntry{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
InputTokens: 100,
|
Tokens: TokenMetrics{
|
||||||
OutputTokens: 50,
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mm.addMetrics(metric)
|
mm.queueMetrics(metric)
|
||||||
|
mm.emitMetric(metric)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case evt := <-receivedEvent:
|
case evt := <-receivedEvent:
|
||||||
assert.Equal(t, 0, evt.Metrics.ID)
|
assert.Equal(t, 0, evt.Metrics.ID)
|
||||||
assert.Equal(t, "test-model", evt.Metrics.Model)
|
assert.Equal(t, "test-model", evt.Metrics.Model)
|
||||||
assert.Equal(t, 100, evt.Metrics.InputTokens)
|
assert.Equal(t, 100, evt.Metrics.Tokens.InputTokens)
|
||||||
assert.Equal(t, 50, evt.Metrics.OutputTokens)
|
assert.Equal(t, 50, evt.Metrics.Tokens.OutputTokens)
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
t.Fatal("timeout waiting for event")
|
t.Fatal("timeout waiting for event")
|
||||||
}
|
}
|
||||||
@@ -111,8 +120,8 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("returns copy of metrics", func(t *testing.T) {
|
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
mm.addMetrics(TokenMetrics{Model: "model1"})
|
mm.queueMetrics(ActivityLogEntry{Model: "model1"})
|
||||||
mm.addMetrics(TokenMetrics{Model: "model2"})
|
mm.queueMetrics(ActivityLogEntry{Model: "model2"})
|
||||||
|
|
||||||
metrics1 := mm.getMetrics()
|
metrics1 := mm.getMetrics()
|
||||||
metrics2 := mm.getMetrics()
|
metrics2 := mm.getMetrics()
|
||||||
@@ -135,7 +144,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, jsonData)
|
assert.NotNil(t, jsonData)
|
||||||
|
|
||||||
var metrics []TokenMetrics
|
var metrics []ActivityLogEntry
|
||||||
err = json.Unmarshal(jsonData, &metrics)
|
err = json.Unmarshal(jsonData, &metrics)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 0, len(metrics))
|
||||||
@@ -143,23 +152,27 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
mm.addMetrics(TokenMetrics{
|
mm.queueMetrics(ActivityLogEntry{
|
||||||
Model: "model1",
|
Model: "model1",
|
||||||
InputTokens: 100,
|
Tokens: TokenMetrics{
|
||||||
OutputTokens: 50,
|
InputTokens: 100,
|
||||||
TokensPerSecond: 25.5,
|
OutputTokens: 50,
|
||||||
|
TokensPerSecond: 25.5,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
mm.addMetrics(TokenMetrics{
|
mm.queueMetrics(ActivityLogEntry{
|
||||||
Model: "model2",
|
Model: "model2",
|
||||||
InputTokens: 200,
|
Tokens: TokenMetrics{
|
||||||
OutputTokens: 100,
|
InputTokens: 200,
|
||||||
TokensPerSecond: 30.0,
|
OutputTokens: 100,
|
||||||
|
TokensPerSecond: 30.0,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
jsonData, err := mm.getMetricsJSON()
|
jsonData, err := mm.getMetricsJSON()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
var metrics []TokenMetrics
|
var metrics []ActivityLogEntry
|
||||||
err = json.Unmarshal(jsonData, &metrics)
|
err = json.Unmarshal(jsonData, &metrics)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 2, len(metrics))
|
assert.Equal(t, 2, len(metrics))
|
||||||
@@ -190,14 +203,14 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("successful request with timings data", func(t *testing.T) {
|
t.Run("successful request with timings data", func(t *testing.T) {
|
||||||
@@ -226,17 +239,17 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||||
assert.Equal(t, 20, metrics[0].CachedTokens)
|
assert.Equal(t, 20, metrics[0].Tokens.CachedTokens)
|
||||||
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
|
assert.Equal(t, 150.5, metrics[0].Tokens.PromptPerSecond)
|
||||||
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
|
assert.Equal(t, 25.5, metrics[0].Tokens.TokensPerSecond)
|
||||||
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
|
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -265,18 +278,18 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
// When timings data is present, it takes precedence
|
// When timings data is present, it takes precedence
|
||||||
assert.Equal(t, 10, metrics[0].InputTokens)
|
assert.Equal(t, 10, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 20, metrics[0].OutputTokens)
|
assert.Equal(t, 20, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
t.Run("non-OK status code records partial metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
@@ -289,11 +302,16 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, "/test", metrics[0].ReqPath)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, metrics[0].RespStatusCode)
|
||||||
|
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||||
|
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("empty response body records minimal metrics", func(t *testing.T) {
|
t.Run("empty response body records minimal metrics", func(t *testing.T) {
|
||||||
@@ -308,14 +326,14 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
|
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
|
||||||
@@ -332,14 +350,14 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("next handler error is propagated", func(t *testing.T) {
|
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||||
@@ -354,7 +372,7 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.Equal(t, expectedErr, err)
|
assert.Equal(t, expectedErr, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
@@ -377,14 +395,14 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
|
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
|
||||||
@@ -416,17 +434,17 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 150, metrics[0].InputTokens)
|
assert.Equal(t, 150, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
|
||||||
assert.Equal(t, 30, metrics[0].CachedTokens)
|
assert.Equal(t, 30, metrics[0].Tokens.CachedTokens)
|
||||||
assert.Equal(t, 200.5, metrics[0].PromptPerSecond)
|
assert.Equal(t, 200.5, metrics[0].Tokens.PromptPerSecond)
|
||||||
assert.Equal(t, 35.5, metrics[0].TokensPerSecond)
|
assert.Equal(t, 35.5, metrics[0].Tokens.TokensPerSecond)
|
||||||
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
|
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -446,14 +464,14 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,15 +490,11 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
|||||||
assert.Equal(t, string(testData), rec.Body.String())
|
assert.Equal(t, string(testData), rec.Body.String())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("sets start time on first write", func(t *testing.T) {
|
t.Run("sets start time on creation", func(t *testing.T) {
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
copier := newBodyCopier(ginCtx.Writer)
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
assert.True(t, copier.StartTime().IsZero())
|
|
||||||
|
|
||||||
copier.Write([]byte("test"))
|
|
||||||
|
|
||||||
assert.False(t, copier.StartTime().IsZero())
|
assert.False(t, copier.StartTime().IsZero())
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -507,7 +521,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||||
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
t.Run("concurrent queueMetrics is safe", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@@ -519,10 +533,12 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
|||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for j := 0; j < metricsPerGoroutine; j++ {
|
for j := 0; j < metricsPerGoroutine; j++ {
|
||||||
mm.addMetrics(TokenMetrics{
|
mm.queueMetrics(ActivityLogEntry{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
InputTokens: id*1000 + j,
|
Tokens: TokenMetrics{
|
||||||
OutputTokens: j,
|
InputTokens: id*1000 + j,
|
||||||
|
OutputTokens: j,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
@@ -542,7 +558,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
|||||||
// Writer goroutine
|
// Writer goroutine
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < 50; i++ {
|
for i := 0; i < 50; i++ {
|
||||||
mm.addMetrics(TokenMetrics{Model: "test-model"})
|
mm.queueMetrics(ActivityLogEntry{Model: "test-model"})
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
}
|
}
|
||||||
done <- true
|
done <- true
|
||||||
@@ -586,10 +602,10 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
|||||||
|
|
||||||
metrics, err := parseMetrics("test-model", start, usage, timings)
|
metrics, err := parseMetrics("test-model", start, usage, timings)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 5, metrics.InputTokens)
|
assert.Equal(t, 5, metrics.Tokens.InputTokens)
|
||||||
assert.Equal(t, 1, metrics.OutputTokens)
|
assert.Equal(t, 1, metrics.Tokens.OutputTokens)
|
||||||
assert.Equal(t, 10.0, metrics.PromptPerSecond)
|
assert.Equal(t, 10.0, metrics.Tokens.PromptPerSecond)
|
||||||
assert.Equal(t, 2.0, metrics.TokensPerSecond)
|
assert.Equal(t, 2.0, metrics.Tokens.TokensPerSecond)
|
||||||
assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
|
assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -623,14 +639,14 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
// Should use timings values, not usage values
|
// Should use timings values, not usage values
|
||||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||||
@@ -658,12 +674,12 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
|
assert.Equal(t, -1, metrics[0].Tokens.CachedTokens) // Default value when not present
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -693,13 +709,13 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
|
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
|
||||||
@@ -722,14 +738,14 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
|
t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
|
||||||
@@ -751,14 +767,14 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 17, metrics[0].InputTokens)
|
assert.Equal(t, 17, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 23, metrics[0].OutputTokens)
|
assert.Equal(t, 23, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
|
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
|
||||||
@@ -777,14 +793,14 @@ data: [DONE]
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -792,20 +808,22 @@ data: [DONE]
|
|||||||
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||||
|
|
||||||
metric := TokenMetrics{
|
metric := ActivityLogEntry{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
CachedTokens: 100,
|
Tokens: TokenMetrics{
|
||||||
InputTokens: 500,
|
CachedTokens: 100,
|
||||||
OutputTokens: 250,
|
InputTokens: 500,
|
||||||
PromptPerSecond: 1200.5,
|
OutputTokens: 250,
|
||||||
TokensPerSecond: 45.8,
|
PromptPerSecond: 1200.5,
|
||||||
DurationMs: 5000,
|
TokensPerSecond: 45.8,
|
||||||
Timestamp: time.Now(),
|
},
|
||||||
|
DurationMs: 5000,
|
||||||
|
Timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
mm.addMetrics(metric)
|
mm.queueMetrics(metric)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -813,20 +831,22 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
|||||||
// Test performance with a smaller buffer where wrapping occurs more frequently
|
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||||
mm := newMetricsMonitor(testLogger, 100, 0)
|
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||||
|
|
||||||
metric := TokenMetrics{
|
metric := ActivityLogEntry{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
CachedTokens: 100,
|
Tokens: TokenMetrics{
|
||||||
InputTokens: 500,
|
CachedTokens: 100,
|
||||||
OutputTokens: 250,
|
InputTokens: 500,
|
||||||
PromptPerSecond: 1200.5,
|
OutputTokens: 250,
|
||||||
TokensPerSecond: 45.8,
|
PromptPerSecond: 1200.5,
|
||||||
DurationMs: 5000,
|
TokensPerSecond: 45.8,
|
||||||
Timestamp: time.Now(),
|
},
|
||||||
|
DurationMs: 5000,
|
||||||
|
Timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
mm.addMetrics(metric)
|
mm.queueMetrics(metric)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -855,14 +875,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("deflate encoded response", func(t *testing.T) {
|
t.Run("deflate encoded response", func(t *testing.T) {
|
||||||
@@ -889,14 +909,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 200, metrics[0].InputTokens)
|
assert.Equal(t, 200, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
|
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
|
||||||
@@ -917,14 +937,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err) // Should not return error, just log warning
|
assert.NoError(t, err) // Should not return error, just log warning
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, "test-model", metrics[0].Model)
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
|
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
|
||||||
@@ -944,13 +964,13 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
assert.Equal(t, 300, metrics[0].InputTokens)
|
assert.Equal(t, 300, metrics[0].Tokens.InputTokens)
|
||||||
assert.Equal(t, 100, metrics[0].OutputTokens)
|
assert.Equal(t, 100, metrics[0].Tokens.OutputTokens)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -989,7 +1009,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
|||||||
mm.addCapture(capture)
|
mm.addCapture(capture)
|
||||||
|
|
||||||
// Should not store capture
|
// Should not store capture
|
||||||
assert.Nil(t, mm.getCaptureByID(0, false))
|
assert.Nil(t, mm.getCaptureByID(0))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("adds capture when enabled", func(t *testing.T) {
|
t.Run("adds capture when enabled", func(t *testing.T) {
|
||||||
@@ -1002,22 +1022,18 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mm.addCapture(capture)
|
mm.addCapture(capture)
|
||||||
|
|
||||||
retrieved := mm.getCaptureByID(0, true)
|
captured := mm.getCaptureByID(0)
|
||||||
assert.NotNil(t, retrieved)
|
assert.NotNil(t, captured)
|
||||||
|
assert.Equal(t, 0, captured.ID)
|
||||||
var decoded ReqRespCapture
|
assert.Equal(t, []byte("test request"), captured.ReqBody)
|
||||||
err := json.Unmarshal(retrieved, &decoded)
|
assert.Equal(t, []byte("test response"), captured.RespBody)
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, 0, decoded.ID)
|
|
||||||
assert.Equal(t, []byte("test request"), decoded.ReqBody)
|
|
||||||
assert.Equal(t, []byte("test response"), decoded.RespBody)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
// Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
|
// Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
|
||||||
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
|
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
|
||||||
mm.maxCaptureSize = 450
|
mm.captureCache = cache.New(450)
|
||||||
|
|
||||||
// Use random-looking data that doesn't compress well with zstd
|
// Use random-looking data that doesn't compress well with zstd
|
||||||
rng := rand.New(rand.NewSource(42))
|
rng := rand.New(rand.NewSource(42))
|
||||||
@@ -1033,16 +1049,14 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
|||||||
// Adding capture3 should evict capture1
|
// Adding capture3 should evict capture1
|
||||||
mm.addCapture(capture3)
|
mm.addCapture(capture3)
|
||||||
|
|
||||||
assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted")
|
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
|
||||||
retrieved := mm.getCaptureByID(1, true)
|
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
|
||||||
assert.NotNil(t, retrieved, "capture 1 should exist")
|
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
|
||||||
retrieved = mm.getCaptureByID(2, true)
|
|
||||||
assert.NotNil(t, retrieved, "capture 2 should exist")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("skips capture larger than max size", func(t *testing.T) {
|
t.Run("skips capture larger than max size", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
mm.maxCaptureSize = 100
|
mm.captureCache = cache.New(100)
|
||||||
|
|
||||||
// Use random data that doesn't compress well to create an oversized capture
|
// Use random data that doesn't compress well to create an oversized capture
|
||||||
rng := rand.New(rand.NewSource(99))
|
rng := rand.New(rand.NewSource(99))
|
||||||
@@ -1050,7 +1064,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
|||||||
rng.Read(largeCapture.ReqBody)
|
rng.Read(largeCapture.ReqBody)
|
||||||
mm.addCapture(largeCapture)
|
mm.addCapture(largeCapture)
|
||||||
|
|
||||||
assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored")
|
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1058,7 +1072,7 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
|||||||
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
assert.Nil(t, mm.getCaptureByID(999, false))
|
assert.Nil(t, mm.getCaptureByID(999))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns decompressed capture by ID", func(t *testing.T) {
|
t.Run("returns decompressed capture by ID", func(t *testing.T) {
|
||||||
@@ -1071,18 +1085,14 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mm.addCapture(capture)
|
mm.addCapture(capture)
|
||||||
|
|
||||||
retrieved := mm.getCaptureByID(42, true)
|
captured := mm.getCaptureByID(42)
|
||||||
assert.NotNil(t, retrieved)
|
assert.NotNil(t, captured)
|
||||||
|
assert.Equal(t, 42, captured.ID)
|
||||||
var decoded ReqRespCapture
|
assert.Equal(t, []byte("test request"), captured.ReqBody)
|
||||||
err := json.Unmarshal(retrieved, &decoded)
|
assert.Equal(t, []byte("test response"), captured.RespBody)
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, 42, decoded.ID)
|
|
||||||
assert.Equal(t, []byte("test request"), decoded.ReqBody)
|
|
||||||
assert.Equal(t, []byte("test response"), decoded.RespBody)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns compressed bytes when decompress=false", func(t *testing.T) {
|
t.Run("stores data as compressed bytes", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
capture := ReqRespCapture{
|
capture := ReqRespCapture{
|
||||||
@@ -1092,10 +1102,12 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mm.addCapture(capture)
|
mm.addCapture(capture)
|
||||||
|
|
||||||
compressed := mm.getCaptureByID(42, false)
|
compressed, exists := mm.getCompressedBytes(42)
|
||||||
|
assert.True(t, exists)
|
||||||
assert.NotNil(t, compressed)
|
assert.NotNil(t, compressed)
|
||||||
// Compressed data should not be valid JSON (it's zstd-compressed)
|
// Compressed data should not be valid CBOR (it's zstd-compressed)
|
||||||
assert.False(t, gjson.ValidBytes(compressed))
|
var decoded ReqRespCapture
|
||||||
|
assert.Error(t, cbor.Unmarshal(compressed, &decoded))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1164,7 +1176,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Check metric was recorded
|
// Check metric was recorded
|
||||||
@@ -1173,12 +1185,8 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
|||||||
metricID := metrics[0].ID
|
metricID := metrics[0].ID
|
||||||
|
|
||||||
// Check capture was stored with same ID (decompressed)
|
// Check capture was stored with same ID (decompressed)
|
||||||
captureData := mm.getCaptureByID(metricID, true)
|
capture := mm.getCaptureByID(metricID)
|
||||||
assert.NotNil(t, captureData)
|
assert.NotNil(t, capture)
|
||||||
|
|
||||||
var capture ReqRespCapture
|
|
||||||
err = json.Unmarshal(captureData, &capture)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, metricID, capture.ID)
|
assert.Equal(t, metricID, capture.ID)
|
||||||
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||||
@@ -1206,7 +1214,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Metrics should still be recorded
|
// Metrics should still be recorded
|
||||||
@@ -1214,7 +1222,168 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
|||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
|
||||||
// But no capture
|
// But no capture
|
||||||
capture := mm.getCaptureByID(metrics[0].ID, false)
|
assert.Nil(t, mm.getCaptureByID(metrics[0].ID))
|
||||||
assert.Nil(t, capture)
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_WrapHandler_PartialCaptures(t *testing.T) {
|
||||||
|
requestBody := `{"model": "test"}`
|
||||||
|
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Custom", "header-value")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("only request headers", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer secret")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||||
|
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||||
|
assert.Nil(t, capture.ReqBody)
|
||||||
|
assert.Nil(t, capture.RespHeaders)
|
||||||
|
assert.Nil(t, capture.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("only request body", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqBody, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
assert.Nil(t, capture.ReqHeaders)
|
||||||
|
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||||
|
assert.Nil(t, capture.RespHeaders)
|
||||||
|
assert.Nil(t, capture.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("only response headers", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespHeaders, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
assert.Nil(t, capture.ReqHeaders)
|
||||||
|
assert.Nil(t, capture.ReqBody)
|
||||||
|
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
|
||||||
|
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
|
||||||
|
assert.Nil(t, capture.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("only response body", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespBody, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
assert.Nil(t, capture.ReqHeaders)
|
||||||
|
assert.Nil(t, capture.ReqBody)
|
||||||
|
assert.Nil(t, capture.RespHeaders)
|
||||||
|
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("captureReqAll", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer secret")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqAll, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||||
|
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||||
|
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||||
|
assert.Nil(t, capture.RespHeaders)
|
||||||
|
assert.Nil(t, capture.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("captureRespAll", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespAll, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
assert.Nil(t, capture.ReqHeaders)
|
||||||
|
assert.Nil(t, capture.ReqBody)
|
||||||
|
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
|
||||||
|
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
|
||||||
|
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no flags", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureFields(0), nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
assert.Nil(t, capture.ReqHeaders)
|
||||||
|
assert.Nil(t, capture.ReqBody)
|
||||||
|
assert.Nil(t, capture.RespHeaders)
|
||||||
|
assert.Nil(t, capture.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mixed flags req headers and resp body", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 100)
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer secret")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders|captureRespBody, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||||
|
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||||
|
assert.Nil(t, capture.ReqBody)
|
||||||
|
assert.Nil(t, capture.RespHeaders)
|
||||||
|
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
+325
-275
@@ -332,41 +332,77 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
// Protected routes use pm.apiKeyAuth() middleware
|
// Protected routes use pm.apiKeyAuth() middleware
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
llmHandler := pm.mkProxyJSONHandler(captureAll)
|
||||||
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
|
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||||
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
// Support anthropic count_tokens API (Also added in the above PR)
|
// Support anthropic count_tokens API (Also added in the above PR)
|
||||||
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
|
|
||||||
// Support embeddings and reranking
|
// Support embeddings and reranking
|
||||||
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
|
|
||||||
// llama-server's /reranking endpoint + aliases
|
// llama-server's /reranking endpoint + aliases
|
||||||
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
|
|
||||||
// llama-server's /infill endpoint for code infilling
|
// llama-server's /infill endpoint for code infilling
|
||||||
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
|
|
||||||
// llama-server's /completion endpoint
|
// llama-server's /completion endpoint
|
||||||
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST(
|
||||||
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
"/v1/audio/speech",
|
||||||
|
pm.apiKeyAuth(),
|
||||||
|
pm.trackInflight(),
|
||||||
|
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
||||||
|
)
|
||||||
|
pm.ginEngine.POST(
|
||||||
|
"/v1/audio/voices",
|
||||||
|
pm.apiKeyAuth(),
|
||||||
|
pm.trackInflight(),
|
||||||
|
pm.mkProxyJSONHandler(captureReqHeaders|captureRespAll),
|
||||||
|
)
|
||||||
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
|
||||||
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST(
|
||||||
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
"/v1/audio/transcriptions",
|
||||||
|
pm.apiKeyAuth(),
|
||||||
|
pm.trackInflight(),
|
||||||
|
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders|captureRespBody),
|
||||||
|
)
|
||||||
|
pm.ginEngine.POST(
|
||||||
|
"/v1/images/generations",
|
||||||
|
pm.apiKeyAuth(),
|
||||||
|
pm.trackInflight(),
|
||||||
|
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
||||||
|
)
|
||||||
|
|
||||||
|
pm.ginEngine.POST(
|
||||||
|
"/v1/images/edits",
|
||||||
|
pm.apiKeyAuth(),
|
||||||
|
pm.trackInflight(),
|
||||||
|
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders),
|
||||||
|
)
|
||||||
|
|
||||||
// sd.cpp /sdapi/v1 endpoints
|
// sd.cpp /sdapi/v1 endpoints
|
||||||
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/sdapi/v1/txt2img",
|
||||||
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
pm.apiKeyAuth(),
|
||||||
|
pm.trackInflight(),
|
||||||
|
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
|
||||||
|
)
|
||||||
|
pm.ginEngine.POST("/sdapi/v1/img2img",
|
||||||
|
pm.apiKeyAuth(),
|
||||||
|
pm.trackInflight(),
|
||||||
|
pm.mkProxyJSONHandler(captureReqHeaders|captureRespHeaders),
|
||||||
|
)
|
||||||
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||||
@@ -647,7 +683,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||||
|
|
||||||
if !modelFound {
|
if !modelFound {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
pm.sendErrorResponse(c, http.StatusNotFound, "model not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -686,7 +722,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
|
|
||||||
// attempt to record metrics if it is a POST request
|
// attempt to record metrics if it is a POST request
|
||||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, captureNone, handler); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||||
return
|
return
|
||||||
@@ -700,280 +736,294 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context) {
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
return func(c *gin.Context) {
|
||||||
if err != nil {
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
if err != nil {
|
||||||
return
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||||
}
|
return
|
||||||
|
|
||||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
|
||||||
if requestedModel == "" {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look for a matching local model first
|
|
||||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
|
||||||
|
|
||||||
modelID, found := pm.config.RealModelName(requestedModel)
|
|
||||||
if found {
|
|
||||||
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
|
||||||
if pm.matrix != nil {
|
|
||||||
localHandler = pm.matrix.ProxyRequest
|
|
||||||
} else {
|
|
||||||
processGroup, err := pm.swapProcessGroup(modelID)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
localHandler = processGroup.ProxyRequest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// issue #69 allow custom model names to be sent to upstream
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
useModelName := pm.config.Models[modelID].UseModelName
|
if requestedModel == "" {
|
||||||
if useModelName != "" {
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
return
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// issue #174 strip parameters from the JSON body
|
// Look for a matching local model first
|
||||||
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||||
if err != nil { // just log it and continue
|
|
||||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
modelID, found := pm.config.RealModelName(requestedModel)
|
||||||
} else {
|
if found {
|
||||||
|
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
||||||
|
if pm.matrix != nil {
|
||||||
|
localHandler = pm.matrix.ProxyRequest
|
||||||
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
localHandler = processGroup.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
|
useModelName := pm.config.Models[modelID].UseModelName
|
||||||
|
if useModelName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #174 strip parameters from the JSON body
|
||||||
|
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||||
|
if err != nil { // just log it and continue
|
||||||
|
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||||
|
} else {
|
||||||
|
for _, param := range stripParams {
|
||||||
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||||
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #453 set/override parameters in the JSON body
|
||||||
|
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||||
|
for _, key := range setParamKeys {
|
||||||
|
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
||||||
|
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
||||||
|
for _, key := range setParamsByIDKeys {
|
||||||
|
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
|
nextHandler = localHandler
|
||||||
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
|
modelID = requestedModel
|
||||||
|
|
||||||
|
// issue #453 apply filters for peer requests
|
||||||
|
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||||
|
|
||||||
|
// Apply stripParams - remove specified parameters from request
|
||||||
|
stripParams := peerFilters.SanitizedStripParams()
|
||||||
for _, param := range stripParams {
|
for _, param := range stripParams {
|
||||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply setParams - set/override specified parameters in request
|
||||||
|
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||||
|
for _, key := range setParamKeys {
|
||||||
|
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nextHandler = pm.peerProxy.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextHandler == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
|
c.Request.Header.Del("transfer-encoding")
|
||||||
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
c.Request.ContentLength = int64(len(bodyBytes))
|
||||||
|
|
||||||
|
// issue #366 extract values that downstream handlers may need
|
||||||
|
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||||
|
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||||
|
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
|
||||||
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, cf, nextHandler); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mkPostFormHandler creates a POST form handler for inference backends
|
||||||
|
// with a custom captureFields to filter out large binary requests or responses.
|
||||||
|
func (pm *ProxyManager) mkPostFormHandler(cf captureFields) func(*gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Parse multipart form
|
||||||
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get model parameter from the form
|
||||||
|
requestedModel := c.Request.FormValue("model")
|
||||||
|
if requestedModel == "" {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for a matching local model first, then check peers
|
||||||
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||||
|
var useModelName string
|
||||||
|
|
||||||
|
modelID, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if found {
|
||||||
|
if pm.matrix != nil {
|
||||||
|
nextHandler = pm.matrix.ProxyRequest
|
||||||
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextHandler = processGroup.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
useModelName = pm.config.Models[modelID].UseModelName
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
|
modelID = requestedModel
|
||||||
|
nextHandler = pm.peerProxy.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextHandler == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
|
// Create a new buffer for the reconstructed request
|
||||||
|
var requestBuffer bytes.Buffer
|
||||||
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||||
|
|
||||||
|
// Copy all form values
|
||||||
|
for key, values := range c.Request.MultipartForm.Value {
|
||||||
|
for _, value := range values {
|
||||||
|
fieldValue := value
|
||||||
|
// If this is the model field and we have a profile, use just the model name
|
||||||
|
if key == "model" {
|
||||||
|
// # issue #69 allow custom model names to be sent to upstream
|
||||||
|
if useModelName != "" {
|
||||||
|
fieldValue = useModelName
|
||||||
|
} else {
|
||||||
|
fieldValue = requestedModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
field, err := multipartWriter.CreateFormField(key)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// issue #453 set/override parameters in the JSON body
|
// Copy all files from the original request
|
||||||
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
for key, fileHeaders := range c.Request.MultipartForm.File {
|
||||||
for _, key := range setParamKeys {
|
for _, fileHeader := range fileHeaders {
|
||||||
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
if err != nil {
|
||||||
if err != nil {
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
|
||||||
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
|
||||||
for _, key := range setParamsByIDKeys {
|
|
||||||
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
|
||||||
nextHandler = localHandler
|
|
||||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
|
||||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
|
||||||
modelID = requestedModel
|
|
||||||
|
|
||||||
// issue #453 apply filters for peer requests
|
|
||||||
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
|
||||||
|
|
||||||
// Apply stripParams - remove specified parameters from request
|
|
||||||
stripParams := peerFilters.SanitizedStripParams()
|
|
||||||
for _, param := range stripParams {
|
|
||||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
|
||||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply setParams - set/override specified parameters in request
|
|
||||||
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
|
||||||
for _, key := range setParamKeys {
|
|
||||||
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
nextHandler = pm.peerProxy.ProxyRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
if nextHandler == nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
|
||||||
c.Request.Header.Del("transfer-encoding")
|
|
||||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
|
||||||
c.Request.ContentLength = int64(len(bodyBytes))
|
|
||||||
|
|
||||||
// issue #366 extract values that downstream handlers may need
|
|
||||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
|
||||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
|
||||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
|
||||||
|
|
||||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
|
||||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
|
||||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|
||||||
// Parse multipart form
|
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get model parameter from the form
|
|
||||||
requestedModel := c.Request.FormValue("model")
|
|
||||||
if requestedModel == "" {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look for a matching local model first, then check peers
|
|
||||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
|
||||||
var useModelName string
|
|
||||||
|
|
||||||
modelID, found := pm.config.RealModelName(requestedModel)
|
|
||||||
if found {
|
|
||||||
if pm.matrix != nil {
|
|
||||||
nextHandler = pm.matrix.ProxyRequest
|
|
||||||
} else {
|
|
||||||
processGroup, err := pm.swapProcessGroup(modelID)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nextHandler = processGroup.ProxyRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
useModelName = pm.config.Models[modelID].UseModelName
|
|
||||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
|
||||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
|
||||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
|
||||||
modelID = requestedModel
|
|
||||||
nextHandler = pm.peerProxy.ProxyRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
if nextHandler == nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
||||||
// Create a new buffer for the reconstructed request
|
|
||||||
var requestBuffer bytes.Buffer
|
|
||||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
||||||
|
|
||||||
// Copy all form values
|
|
||||||
for key, values := range c.Request.MultipartForm.Value {
|
|
||||||
for _, value := range values {
|
|
||||||
fieldValue := value
|
|
||||||
// If this is the model field and we have a profile, use just the model name
|
|
||||||
if key == "model" {
|
|
||||||
// # issue #69 allow custom model names to be sent to upstream
|
|
||||||
if useModelName != "" {
|
|
||||||
fieldValue = useModelName
|
|
||||||
} else {
|
|
||||||
fieldValue = requestedModel
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
field, err := multipartWriter.CreateFormField(key)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy all files from the original request
|
file, err := fileHeader.Open()
|
||||||
for key, fileHeaders := range c.Request.MultipartForm.File {
|
if err != nil {
|
||||||
for _, fileHeader := range fileHeaders {
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
||||||
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
return
|
||||||
if err != nil {
|
}
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := fileHeader.Open()
|
if _, err = io.Copy(formFile, file); err != nil {
|
||||||
if err != nil {
|
file.Close()
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = io.Copy(formFile, file); err != nil {
|
|
||||||
file.Close()
|
file.Close()
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the multipart writer to finalize the form
|
||||||
|
if err := multipartWriter.Close(); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new request with the reconstructed form data
|
||||||
|
modifiedReq, err := http.NewRequestWithContext(
|
||||||
|
c.Request.Context(),
|
||||||
|
c.Request.Method,
|
||||||
|
c.Request.URL.String(),
|
||||||
|
&requestBuffer,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the headers from the original request
|
||||||
|
modifiedReq.Header = c.Request.Header.Clone()
|
||||||
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||||
|
|
||||||
|
// set the content length of the body
|
||||||
|
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||||
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||||
|
|
||||||
|
// Use the modified request for proxying
|
||||||
|
if pm.metricsMonitor != nil {
|
||||||
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, modifiedReq, cf, nextHandler); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
file.Close()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the multipart writer to finalize the form
|
|
||||||
if err := multipartWriter.Close(); err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new request with the reconstructed form data
|
|
||||||
modifiedReq, err := http.NewRequestWithContext(
|
|
||||||
c.Request.Context(),
|
|
||||||
c.Request.Method,
|
|
||||||
c.Request.URL.String(),
|
|
||||||
&requestBuffer,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy the headers from the original request
|
|
||||||
modifiedReq.Header = c.Request.Header.Clone()
|
|
||||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
|
||||||
|
|
||||||
// set the content length of the body
|
|
||||||
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
|
||||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
|
||||||
|
|
||||||
// Use the modified request for proxying
|
|
||||||
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
||||||
|
|||||||
+10
-20
@@ -158,7 +158,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sendMetrics := func(metrics []TokenMetrics) {
|
sendMetrics := func(metrics []ActivityLogEntry) {
|
||||||
jsonData, err := json.Marshal(metrics)
|
jsonData, err := json.Marshal(metrics)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
select {
|
select {
|
||||||
@@ -205,8 +205,8 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
|||||||
/**
|
/**
|
||||||
* Send Metrics data
|
* Send Metrics data
|
||||||
*/
|
*/
|
||||||
defer event.On(func(e TokenMetricsEvent) {
|
defer event.On(func(e ActivityLogEvent) {
|
||||||
sendMetrics([]TokenMetrics{e.Metrics})
|
sendMetrics([]ActivityLogEntry{e.Metrics})
|
||||||
})()
|
})()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -290,26 +290,16 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, exists := pm.metricsMonitor.getCompressedBytes(id)
|
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||||
if !exists {
|
if capture == nil {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Header("Vary", "Accept-Encoding")
|
jsonBytes, err := json.Marshal(capture)
|
||||||
|
if err != nil {
|
||||||
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"})
|
||||||
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd")
|
return
|
||||||
|
|
||||||
if hasZstd {
|
|
||||||
c.Header("Content-Encoding", "zstd")
|
|
||||||
c.Data(http.StatusOK, "application/json", data)
|
|
||||||
} else {
|
|
||||||
decompressed, err := decompressCapture(data)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Data(http.StatusOK, "application/json", decompressed)
|
|
||||||
}
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", jsonBytes)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,6 +32,13 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
c.Header("X-Accel-Buffering", "no")
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||||
|
|
||||||
|
// Handle case where query string might be included in the parameter
|
||||||
|
// (can happen with catch-all routes on some versions/setups)
|
||||||
|
if idx := strings.Index(logMonitorId, "?"); idx != -1 {
|
||||||
|
logMonitorId = logMonitorId[:idx]
|
||||||
|
}
|
||||||
|
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
logger, err := pm.getLogger(logMonitorId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadRequest, err.Error())
|
c.String(http.StatusBadRequest, err.Error())
|
||||||
@@ -100,6 +107,12 @@ func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
|||||||
return process.Logger(), nil
|
return process.Logger(), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// also check the matrix when processGroups doesn't contain the model
|
||||||
|
if pm.matrix != nil {
|
||||||
|
if process, found := pm.matrix.GetProcess(name); found {
|
||||||
|
return process.Logger(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||||
|
|||||||
@@ -0,0 +1,173 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogMonitorIdQueryParameterStripping(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "upstream without query param",
|
||||||
|
input: "upstream",
|
||||||
|
expected: "upstream",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "upstream with query param",
|
||||||
|
input: "upstream?no-history",
|
||||||
|
expected: "upstream",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy with multiple query params",
|
||||||
|
input: "proxy?no-history&foo=bar",
|
||||||
|
expected: "proxy",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with slash and query param",
|
||||||
|
input: "author/model?no-history",
|
||||||
|
expected: "author/model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Simulate the query parameter stripping logic
|
||||||
|
logMonitorId := tt.input
|
||||||
|
if idx := strings.Index(logMonitorId, "?"); idx != -1 {
|
||||||
|
logMonitorId = logMonitorId[:idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
if logMonitorId != tt.expected {
|
||||||
|
t.Errorf("Query parameter stripping failed: got %q, want %q", logMonitorId, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProxyManager_GetLogger_ProcessGroups verifies getLogger resolves the
|
||||||
|
// well-known "proxy"/"upstream" loggers and a model ID managed by processGroups.
|
||||||
|
func TestProxyManager_GetLogger_ProcessGroups(t *testing.T) {
|
||||||
|
cfg := testConfigFromYAML(t, `
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
logLevel: error
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
||||||
|
`)
|
||||||
|
pm := New(cfg)
|
||||||
|
defer pm.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
id string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"proxy", false},
|
||||||
|
{"upstream", false},
|
||||||
|
{"model1", false},
|
||||||
|
{"does-not-exist", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.id, func(t *testing.T) {
|
||||||
|
logger, err := pm.getLogger(tt.id)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid logger")
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, logger)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProxyManager_GetLogger_Matrix verifies that getLogger can resolve a model
|
||||||
|
// ID when the proxy is configured with a swap matrix (pm.processGroups is empty
|
||||||
|
// for matrix-managed models).
|
||||||
|
func TestProxyManager_GetLogger_Matrix(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
ExpandedSets: []config.ExpandedSet{
|
||||||
|
{SetName: "s1", Models: []string{"model1", "model2"}},
|
||||||
|
},
|
||||||
|
Matrix: &config.MatrixConfig{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm := New(cfg)
|
||||||
|
defer pm.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
id string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"proxy", false},
|
||||||
|
{"upstream", false},
|
||||||
|
{"model1", false},
|
||||||
|
{"model2", false},
|
||||||
|
{"does-not-exist", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.id, func(t *testing.T) {
|
||||||
|
logger, err := pm.getLogger(tt.id)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid logger")
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, logger)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProxyManager_StreamLogs_Matrix verifies that /logs/stream/<modelID>
|
||||||
|
// returns 200 (not 400) for a model managed by the swap matrix.
|
||||||
|
func TestProxyManager_StreamLogs_Matrix(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"matrix-model": getTestSimpleResponderConfig("matrix-model"),
|
||||||
|
},
|
||||||
|
ExpandedSets: []config.ExpandedSet{
|
||||||
|
{SetName: "s1", Models: []string{"matrix-model"}},
|
||||||
|
},
|
||||||
|
Matrix: &config.MatrixConfig{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm := New(cfg)
|
||||||
|
defer pm.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/logs/stream/matrix-model", nil)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
rec := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
pm.ServeHTTP(rec, req)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
assert.Equal(t, 200, rec.Code)
|
||||||
|
}
|
||||||
@@ -1721,3 +1721,61 @@ models:
|
|||||||
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_AudioTranscriptionCapture(t *testing.T) {
|
||||||
|
cfg := testConfigFromYAML(t, `
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
logLevel: error
|
||||||
|
captureBuffer: 5
|
||||||
|
models:
|
||||||
|
TheExpectedModel:
|
||||||
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel
|
||||||
|
`)
|
||||||
|
|
||||||
|
proxy := New(cfg)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
injectTestHandlers(proxy, nil)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
w := multipart.NewWriter(&b)
|
||||||
|
|
||||||
|
fw, err := w.CreateFormField("model")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = fw.Write([]byte("TheExpectedModel"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = fw.Write([]byte("test audio content"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
req.Header.Set("Authorization", "Bearer mysecret")
|
||||||
|
req.Header.Set("X-Custom-Req", "req-value")
|
||||||
|
rec := CreateTestResponseRecorder()
|
||||||
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
// Verify capture exists
|
||||||
|
metrics := proxy.metricsMonitor.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.True(t, metrics[0].HasCapture)
|
||||||
|
|
||||||
|
capture := proxy.metricsMonitor.getCaptureByID(metrics[0].ID)
|
||||||
|
assert.NotNil(t, capture)
|
||||||
|
|
||||||
|
// Should capture request headers (sensitive ones redacted)
|
||||||
|
assert.NotEmpty(t, capture.ReqHeaders)
|
||||||
|
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||||
|
assert.Equal(t, "req-value", capture.ReqHeaders["X-Custom-Req"])
|
||||||
|
|
||||||
|
// Should capture response headers
|
||||||
|
assert.NotNil(t, capture.RespHeaders)
|
||||||
|
|
||||||
|
// Should NOT capture request bodies but get response bodies (text
|
||||||
|
assert.Nil(t, capture.ReqBody)
|
||||||
|
assert.NotNil(t, capture.RespBody)
|
||||||
|
}
|
||||||
|
|||||||
Generated
+3
-3
@@ -2788,9 +2788,9 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/postcss": {
|
"node_modules/postcss": {
|
||||||
"version": "8.5.8",
|
"version": "8.5.12",
|
||||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz",
|
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
|
||||||
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==",
|
"integrity": "sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"funding": [
|
"funding": [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -9,13 +9,13 @@
|
|||||||
|
|
||||||
let stats = $derived.by(() => {
|
let stats = $derived.by(() => {
|
||||||
const totalRequests = $metrics.length;
|
const totalRequests = $metrics.length;
|
||||||
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.tokens.input_tokens, 0);
|
||||||
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.tokens.output_tokens, 0);
|
||||||
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.cache_tokens, 0);
|
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.tokens.cache_tokens, 0);
|
||||||
|
|
||||||
const promptPerSecond = $metrics.filter((m) => m.prompt_per_second > 0).map((m) => m.prompt_per_second);
|
const promptPerSecond = $metrics.filter((m) => m.tokens.prompt_per_second > 0).map((m) => m.tokens.prompt_per_second);
|
||||||
|
|
||||||
const tokensPerSecond = $metrics.filter((m) => m.tokens_per_second > 0).map((m) => m.tokens_per_second);
|
const tokensPerSecond = $metrics.filter((m) => m.tokens.tokens_per_second > 0).map((m) => m.tokens.tokens_per_second);
|
||||||
|
|
||||||
const promptHistogramData =
|
const promptHistogramData =
|
||||||
promptPerSecond.length > 0 ? calculateHistogramData(promptPerSecond) : null;
|
promptPerSecond.length > 0 ? calculateHistogramData(promptPerSecond) : null;
|
||||||
|
|||||||
@@ -12,15 +12,22 @@ export interface Model {
|
|||||||
aliases?: string[];
|
aliases?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Metrics {
|
export interface TokenMetrics {
|
||||||
id: number;
|
|
||||||
timestamp: string;
|
|
||||||
model: string;
|
|
||||||
cache_tokens: number;
|
cache_tokens: number;
|
||||||
input_tokens: number;
|
input_tokens: number;
|
||||||
output_tokens: number;
|
output_tokens: number;
|
||||||
prompt_per_second: number;
|
prompt_per_second: number;
|
||||||
tokens_per_second: number;
|
tokens_per_second: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ActivityLogEntry {
|
||||||
|
id: number;
|
||||||
|
timestamp: string;
|
||||||
|
model: string;
|
||||||
|
req_path: string;
|
||||||
|
resp_content_type: string;
|
||||||
|
resp_status_code: number;
|
||||||
|
tokens: TokenMetrics;
|
||||||
duration_ms: number;
|
duration_ms: number;
|
||||||
has_capture: boolean;
|
has_capture: boolean;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,87 @@
|
|||||||
import ActivityStats from "../components/ActivityStats.svelte";
|
import ActivityStats from "../components/ActivityStats.svelte";
|
||||||
import Tooltip from "../components/Tooltip.svelte";
|
import Tooltip from "../components/Tooltip.svelte";
|
||||||
import CaptureDialog from "../components/CaptureDialog.svelte";
|
import CaptureDialog from "../components/CaptureDialog.svelte";
|
||||||
|
import { persistentStore } from "../stores/persistent";
|
||||||
|
import { onMount } from "svelte";
|
||||||
import type { ReqRespCapture } from "../lib/types";
|
import type { ReqRespCapture } from "../lib/types";
|
||||||
|
|
||||||
|
type ColumnKey =
|
||||||
|
| "id"
|
||||||
|
| "time"
|
||||||
|
| "model"
|
||||||
|
| "req_path"
|
||||||
|
| "resp_status_code"
|
||||||
|
| "resp_content_type"
|
||||||
|
| "cached"
|
||||||
|
| "prompt"
|
||||||
|
| "generated"
|
||||||
|
| "prompt_speed"
|
||||||
|
| "gen_speed"
|
||||||
|
| "duration"
|
||||||
|
| "capture";
|
||||||
|
|
||||||
|
interface ColumnDef {
|
||||||
|
key: ColumnKey;
|
||||||
|
label: string;
|
||||||
|
defaultVisible: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const columns: ColumnDef[] = [
|
||||||
|
{ key: "id", label: "ID", defaultVisible: true },
|
||||||
|
{ key: "time", label: "Time", defaultVisible: true },
|
||||||
|
{ key: "model", label: "Model", defaultVisible: true },
|
||||||
|
{ key: "req_path", label: "Path", defaultVisible: false },
|
||||||
|
{ key: "resp_status_code", label: "Status", defaultVisible: false },
|
||||||
|
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
||||||
|
{ key: "cached", label: "Cached", defaultVisible: true },
|
||||||
|
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
||||||
|
{ key: "generated", label: "Generated", defaultVisible: true },
|
||||||
|
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
||||||
|
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
||||||
|
{ key: "duration", label: "Duration", defaultVisible: true },
|
||||||
|
{ key: "capture", label: "Capture", defaultVisible: true },
|
||||||
|
];
|
||||||
|
|
||||||
|
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
|
||||||
|
|
||||||
|
const visibleColumns = persistentStore<ColumnKey[]>(
|
||||||
|
"activity-columns",
|
||||||
|
defaultVisibleKeys
|
||||||
|
);
|
||||||
|
|
||||||
|
let columnsMenuOpen = $state(false);
|
||||||
|
let dropdownContainer: HTMLDivElement | null = null;
|
||||||
|
|
||||||
|
onMount(() => {
|
||||||
|
function handleKeydown(e: KeyboardEvent) {
|
||||||
|
if (e.key === "Escape" && columnsMenuOpen) {
|
||||||
|
columnsMenuOpen = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function handleClick(e: MouseEvent) {
|
||||||
|
if (columnsMenuOpen && dropdownContainer && !dropdownContainer.contains(e.target as Node)) {
|
||||||
|
columnsMenuOpen = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
document.addEventListener("keydown", handleKeydown);
|
||||||
|
document.addEventListener("click", handleClick);
|
||||||
|
return () => {
|
||||||
|
document.removeEventListener("keydown", handleKeydown);
|
||||||
|
document.removeEventListener("click", handleClick);
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
function toggleColumn(key: ColumnKey) {
|
||||||
|
const current = $visibleColumns;
|
||||||
|
if (current.includes(key)) {
|
||||||
|
if (current.length > 1) {
|
||||||
|
visibleColumns.set(current.filter((k) => k !== key));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
visibleColumns.set([...current, key]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function formatSpeed(speed: number): string {
|
function formatSpeed(speed: number): string {
|
||||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||||
}
|
}
|
||||||
@@ -67,58 +146,150 @@
|
|||||||
<ActivityStats />
|
<ActivityStats />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="card overflow-auto">
|
<div class="card overflow-auto relative min-h-[30rem]">
|
||||||
|
<div class="flex justify-end px-4" bind:this={dropdownContainer}>
|
||||||
|
<div class="relative">
|
||||||
|
<button
|
||||||
|
class="w-8 h-8 flex items-center justify-center rounded hover:bg-secondary-hover transition-colors"
|
||||||
|
onclick={() => (columnsMenuOpen = !columnsMenuOpen)}
|
||||||
|
title="Select columns"
|
||||||
|
>
|
||||||
|
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6V4m0 2a2 2 0 100 4m0-4a2 2 0 110 4m-6 8a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4m6 6v10m6-2a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4"></path>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
{#if columnsMenuOpen}
|
||||||
|
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]">
|
||||||
|
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10">
|
||||||
|
Columns
|
||||||
|
</div>
|
||||||
|
{#each columns as col (col.key)}
|
||||||
|
<label
|
||||||
|
class="flex items-center gap-2 px-3 py-1.5 text-sm cursor-pointer hover:bg-secondary-hover transition-colors"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={$visibleColumns.includes(col.key)}
|
||||||
|
onchange={() => toggleColumn(col.key)}
|
||||||
|
class="rounded"
|
||||||
|
/>
|
||||||
|
{col.label}
|
||||||
|
</label>
|
||||||
|
{/each}
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<table class="min-w-full divide-y">
|
<table class="min-w-full divide-y">
|
||||||
<thead class="border-gray-200 dark:border-white/10">
|
<thead class="border-gray-200 dark:border-white/10">
|
||||||
<tr class="text-left text-xs uppercase tracking-wider">
|
<tr class="text-left text-xs uppercase tracking-wider">
|
||||||
<th class="px-6 py-3">ID</th>
|
{#if $visibleColumns.includes("id")}
|
||||||
<th class="px-6 py-3">Time</th>
|
<th class="px-6 py-3">ID</th>
|
||||||
<th class="px-6 py-3">Model</th>
|
{/if}
|
||||||
<th class="px-6 py-3">
|
{#if $visibleColumns.includes("time")}
|
||||||
Cached <Tooltip content="prompt tokens from cache" />
|
<th class="px-6 py-3">Time</th>
|
||||||
</th>
|
{/if}
|
||||||
<th class="px-6 py-3">
|
{#if $visibleColumns.includes("model")}
|
||||||
Prompt <Tooltip content="new prompt tokens processed" />
|
<th class="px-6 py-3">Model</th>
|
||||||
</th>
|
{/if}
|
||||||
<th class="px-6 py-3">Generated</th>
|
{#if $visibleColumns.includes("req_path")}
|
||||||
<th class="px-6 py-3">Prompt Processing</th>
|
<th class="px-6 py-3">Path</th>
|
||||||
<th class="px-6 py-3">Generation Speed</th>
|
{/if}
|
||||||
<th class="px-6 py-3">Duration</th>
|
{#if $visibleColumns.includes("resp_status_code")}
|
||||||
<th class="px-6 py-3">Capture</th>
|
<th class="px-6 py-3">Status</th>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("resp_content_type")}
|
||||||
|
<th class="px-6 py-3">Content-Type</th>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("cached")}
|
||||||
|
<th class="px-6 py-3">
|
||||||
|
Cached <Tooltip content="prompt tokens from cache" />
|
||||||
|
</th>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("prompt")}
|
||||||
|
<th class="px-6 py-3">
|
||||||
|
Prompt <Tooltip content="new prompt tokens processed" />
|
||||||
|
</th>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("generated")}
|
||||||
|
<th class="px-6 py-3">Generated</th>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("prompt_speed")}
|
||||||
|
<th class="px-6 py-3">Prompt Speed</th>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("gen_speed")}
|
||||||
|
<th class="px-6 py-3">Gen Speed</th>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("duration")}
|
||||||
|
<th class="px-6 py-3">Duration</th>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("capture")}
|
||||||
|
<th class="px-6 py-3">Capture</th>
|
||||||
|
{/if}
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody class="divide-y">
|
<tbody class="divide-y">
|
||||||
{#if sortedMetrics.length === 0}
|
{#if sortedMetrics.length === 0}
|
||||||
<tr>
|
<tr>
|
||||||
<td colspan="10" class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
|
<td colspan={$visibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
|
||||||
No activity recorded
|
No activity recorded
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
{:else}
|
{:else}
|
||||||
{#each sortedMetrics as metric (metric.id)}
|
{#each sortedMetrics as metric (metric.id)}
|
||||||
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
||||||
<td class="px-4 py-4">{metric.id + 1}</td>
|
{#if $visibleColumns.includes("id")}
|
||||||
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
|
<td class="px-4 py-4">{metric.id + 1}</td>
|
||||||
<td class="px-6 py-4">{metric.model}</td>
|
{/if}
|
||||||
<td class="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td>
|
{#if $visibleColumns.includes("time")}
|
||||||
<td class="px-6 py-4">{metric.input_tokens.toLocaleString()}</td>
|
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
|
||||||
<td class="px-6 py-4">{metric.output_tokens.toLocaleString()}</td>
|
{/if}
|
||||||
<td class="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td>
|
{#if $visibleColumns.includes("model")}
|
||||||
<td class="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td>
|
<td class="px-6 py-4">{metric.model}</td>
|
||||||
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
|
{/if}
|
||||||
<td class="px-6 py-4">
|
{#if $visibleColumns.includes("req_path")}
|
||||||
{#if metric.has_capture}
|
<td class="px-6 py-4">{metric.req_path || "-"}</td>
|
||||||
<button
|
{/if}
|
||||||
onclick={() => viewCapture(metric.id)}
|
{#if $visibleColumns.includes("resp_status_code")}
|
||||||
disabled={loadingCaptureId === metric.id}
|
<td class="px-6 py-4">{metric.resp_status_code || "-"}</td>
|
||||||
class="btn btn--sm"
|
{/if}
|
||||||
>
|
{#if $visibleColumns.includes("resp_content_type")}
|
||||||
{loadingCaptureId === metric.id ? "..." : "View"}
|
<td class="px-6 py-4">{metric.resp_content_type || "-"}</td>
|
||||||
</button>
|
{/if}
|
||||||
{:else}
|
{#if $visibleColumns.includes("cached")}
|
||||||
<span class="text-txtsecondary">-</span>
|
<td class="px-6 py-4">{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}</td>
|
||||||
{/if}
|
{/if}
|
||||||
</td>
|
{#if $visibleColumns.includes("prompt")}
|
||||||
|
<td class="px-6 py-4">{metric.tokens.input_tokens.toLocaleString()}</td>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("generated")}
|
||||||
|
<td class="px-6 py-4">{metric.tokens.output_tokens.toLocaleString()}</td>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("prompt_speed")}
|
||||||
|
<td class="px-6 py-4">{formatSpeed(metric.tokens.prompt_per_second)}</td>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("gen_speed")}
|
||||||
|
<td class="px-6 py-4">{formatSpeed(metric.tokens.tokens_per_second)}</td>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("duration")}
|
||||||
|
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
|
||||||
|
{/if}
|
||||||
|
{#if $visibleColumns.includes("capture")}
|
||||||
|
<td class="px-6 py-4">
|
||||||
|
{#if metric.has_capture}
|
||||||
|
<button
|
||||||
|
onclick={() => viewCapture(metric.id)}
|
||||||
|
disabled={loadingCaptureId === metric.id}
|
||||||
|
class="btn btn--sm"
|
||||||
|
>
|
||||||
|
{loadingCaptureId === metric.id ? "..." : "View"}
|
||||||
|
</button>
|
||||||
|
{:else}
|
||||||
|
<span class="text-txtsecondary">-</span>
|
||||||
|
{/if}
|
||||||
|
</td>
|
||||||
|
{/if}
|
||||||
</tr>
|
</tr>
|
||||||
{/each}
|
{/each}
|
||||||
{/if}
|
{/if}
|
||||||
|
|||||||
@@ -10,49 +10,35 @@
|
|||||||
const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels");
|
const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels");
|
||||||
|
|
||||||
let direction = $derived<"horizontal" | "vertical">(
|
let direction = $derived<"horizontal" | "vertical">(
|
||||||
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal"
|
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal",
|
||||||
);
|
);
|
||||||
|
|
||||||
function cycleViewMode(): void {
|
|
||||||
const modes: ViewMode[] = ["panels", "proxy", "upstream"];
|
|
||||||
const currentIndex = modes.indexOf($viewModeStore);
|
|
||||||
const nextIndex = (currentIndex + 1) % modes.length;
|
|
||||||
viewModeStore.set(modes[nextIndex]);
|
|
||||||
}
|
|
||||||
|
|
||||||
function getViewModeIcon(mode: ViewMode): string {
|
|
||||||
switch (mode) {
|
|
||||||
case "proxy":
|
|
||||||
return "P";
|
|
||||||
case "upstream":
|
|
||||||
return "U";
|
|
||||||
case "panels":
|
|
||||||
return "⊞";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function getViewModeLabel(mode: ViewMode): string {
|
|
||||||
switch (mode) {
|
|
||||||
case "proxy":
|
|
||||||
return "Proxy";
|
|
||||||
case "upstream":
|
|
||||||
return "Upstream";
|
|
||||||
case "panels":
|
|
||||||
return "Panels";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div class="flex flex-col h-full w-full gap-2">
|
<div class="flex flex-col h-full w-full gap-2">
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-center gap-1">
|
||||||
<button
|
<button
|
||||||
onclick={cycleViewMode}
|
onclick={() => viewModeStore.set("panels")}
|
||||||
class="btn flex items-center gap-2 text-sm"
|
class:btn={true}
|
||||||
title="Toggle view mode"
|
class:bg-primary={$viewModeStore === "panels"}
|
||||||
aria-label="Toggle view mode: {getViewModeLabel($viewModeStore)}"
|
class:text-btn-primary-text={$viewModeStore === "panels"}
|
||||||
>
|
>
|
||||||
<span class="font-mono font-bold">{getViewModeIcon($viewModeStore)}</span>
|
Both
|
||||||
<span>{getViewModeLabel($viewModeStore)}</span>
|
</button>
|
||||||
|
<button
|
||||||
|
onclick={() => viewModeStore.set("proxy")}
|
||||||
|
class:btn={true}
|
||||||
|
class:bg-primary={$viewModeStore === "proxy"}
|
||||||
|
class:text-btn-primary-text={$viewModeStore === "proxy"}
|
||||||
|
>
|
||||||
|
Proxy
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onclick={() => viewModeStore.set("upstream")}
|
||||||
|
class:btn={true}
|
||||||
|
class:bg-primary={$viewModeStore === "upstream"}
|
||||||
|
class:text-btn-primary-text={$viewModeStore === "upstream"}
|
||||||
|
>
|
||||||
|
Upstream
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,13 @@
|
|||||||
import { writable } from "svelte/store";
|
import { writable } from "svelte/store";
|
||||||
import type { Model, Metrics, VersionInfo, LogData, APIEventEnvelope, ReqRespCapture, InFlightStats } from "../lib/types";
|
import type {
|
||||||
|
Model,
|
||||||
|
ActivityLogEntry,
|
||||||
|
VersionInfo,
|
||||||
|
LogData,
|
||||||
|
APIEventEnvelope,
|
||||||
|
ReqRespCapture,
|
||||||
|
InFlightStats,
|
||||||
|
} from "../lib/types";
|
||||||
import { connectionState } from "./theme";
|
import { connectionState } from "./theme";
|
||||||
|
|
||||||
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||||
@@ -8,7 +16,7 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
|||||||
export const models = writable<Model[]>([]);
|
export const models = writable<Model[]>([]);
|
||||||
export const proxyLogs = writable<string>("");
|
export const proxyLogs = writable<string>("");
|
||||||
export const upstreamLogs = writable<string>("");
|
export const upstreamLogs = writable<string>("");
|
||||||
export const metrics = writable<Metrics[]>([]);
|
export const metrics = writable<ActivityLogEntry[]>([]);
|
||||||
export const inFlightRequests = writable<number>(0);
|
export const inFlightRequests = writable<number>(0);
|
||||||
export const versionInfo = writable<VersionInfo>({
|
export const versionInfo = writable<VersionInfo>({
|
||||||
build_date: "unknown",
|
build_date: "unknown",
|
||||||
@@ -62,7 +70,7 @@ export function enableAPIEvents(enabled: boolean): void {
|
|||||||
const newModels = JSON.parse(message.data) as Model[];
|
const newModels = JSON.parse(message.data) as Model[];
|
||||||
// Sort models by name and id
|
// Sort models by name and id
|
||||||
newModels.sort((a, b) => {
|
newModels.sort((a, b) => {
|
||||||
return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric : true} );
|
return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric: true });
|
||||||
});
|
});
|
||||||
models.set(newModels);
|
models.set(newModels);
|
||||||
break;
|
break;
|
||||||
@@ -82,7 +90,7 @@ export function enableAPIEvents(enabled: boolean): void {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "metrics": {
|
case "metrics": {
|
||||||
const newMetrics = JSON.parse(message.data) as Metrics[];
|
const newMetrics = JSON.parse(message.data) as ActivityLogEntry[];
|
||||||
metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]);
|
metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import { persistentStore } from "./persistent";
|
|||||||
import type { ScreenWidth } from "../lib/types";
|
import type { ScreenWidth } from "../lib/types";
|
||||||
|
|
||||||
// Persistent stores
|
// Persistent stores
|
||||||
export const isDarkMode = persistentStore<boolean>("theme", false);
|
const systemDark = typeof window !== "undefined" && window.matchMedia("(prefers-color-scheme: dark)").matches;
|
||||||
|
export const isDarkMode = persistentStore<boolean>("theme", systemDark);
|
||||||
export const appTitle = persistentStore<string>("app-title", "llama-swap");
|
export const appTitle = persistentStore<string>("app-title", "llama-swap");
|
||||||
|
|
||||||
// Non-persistent stores
|
// Non-persistent stores
|
||||||
|
|||||||
@@ -26,6 +26,10 @@ export default defineConfig({
|
|||||||
assetsDir: "assets",
|
assetsDir: "assets",
|
||||||
},
|
},
|
||||||
server: {
|
server: {
|
||||||
|
// yes very insecure but who's running this thing
|
||||||
|
// on the public internet for dev?! haha.
|
||||||
|
host: "0.0.0.0",
|
||||||
|
allowedHosts: true,
|
||||||
proxy: {
|
proxy: {
|
||||||
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
||||||
"/logs": "http://localhost:8080",
|
"/logs": "http://localhost:8080",
|
||||||
|
|||||||
Reference in New Issue
Block a user