Compare commits

...

3 Commits

Author SHA1 Message Date
Benson Wong f852689104 proxy: add panic recovery to Process.ProxyRequest (#363)
Switching to use httputil.ReverseProxy in #342 introduced a possible
panic if a client disconnects while streaming the body. Since llama-swap
does not use http.Server the recover() is not automatically there.

- introduce a recover() in Process.ProxyRequest to recover and log the
  event
- add TestProcess_ReverseProxyPanicIsHandled to reproduce and test the
  fix

fixes: #362
2025-10-25 20:40:05 -07:00
Benson Wong e250e71e59 Include metrics from upstream chat requests (#361)
* proxy: refactor metrics recording

- remove metrics_middleware.go as this wrapper is no longer needed. This
  also eliminiates double body parsing for the modelID
- move metrics parsing to be part of MetricsMonitor
- refactor how metrics are recording in ProxyManager
- add MetricsMonitor tests
- improve mem efficiency of processStreamingResponse
- add benchmarks for MetricsMonitor.addMetrics
- proxy: refactor MetricsMonitor to be more safe handling errors
2025-10-25 17:38:18 -07:00
Benson Wong d18dc26d01 cmd/wol-proxy: tweak logs to show what is causing wake ups (#356)
fix the extra wake ups being caused by wol-proxy

* cmd/wol-proxy: tweak logs to show what is causing wake ups
* cmd/wol-proxy: add skip wakeup
* cmd/wol-proxy: replace ticker with SSE connection
* cmd/wol-proxy: increase scanner buffer size
* cmd/wol-proxy: improve failure tracking
2025-10-25 11:04:31 -07:00
10 changed files with 1127 additions and 325 deletions
+1 -1
View File
@@ -90,7 +90,7 @@ GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
wol-proxy: $(BUILD_DIR)
@echo "Building wol-proxy"
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH) cmd/wol-proxy/wol-proxy.go
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
# Phony targets
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
+104 -35
View File
@@ -1,6 +1,7 @@
package main
import (
"bufio"
"context"
"errors"
"flag"
@@ -13,6 +14,7 @@ import (
"net/url"
"os"
"os/signal"
"strings"
"sync"
"time"
)
@@ -92,14 +94,9 @@ func main() {
}()
// graceful shutdown
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
defer stop()
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
slog.Error("server shutdown error", "error", err)
}
server.Close()
}
type upstreamStatus string
@@ -124,37 +121,86 @@ func newProxy(url *url.URL) *proxyServer {
failCount: 0,
}
// start a goroutien to check upstream status
// start a goroutine to monitor upstream status via SSE
go func() {
checkUrl := url.Scheme + "://" + url.Host + "/wol-health"
client := &http.Client{Timeout: time.Second}
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for range ticker.C {
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
client := &http.Client{
Timeout: 0, // No timeout for SSE connection
}
slog.Debug("checking upstream status at", "url", checkUrl)
resp, err := client.Get(checkUrl)
waitDuration := 10 * time.Second
// drain the body
if err == nil && resp != nil {
for {
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
req, err := http.NewRequest("GET", eventsUrl, nil)
if err != nil {
slog.Warn("failed to create SSE request", "error", err)
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(waitDuration)
continue
}
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
resp, err := client.Do(req)
if err != nil {
slog.Error("failed to connect to SSE endpoint", "error", err)
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(10 * time.Second)
continue
}
if resp.StatusCode != http.StatusOK {
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()
}
if err == nil && resp != nil && resp.StatusCode == http.StatusOK {
slog.Debug("upstream status: ready")
proxy.setStatus(ready)
proxy.statusMutex.Lock()
proxy.failCount = 0
proxy.statusMutex.Unlock()
} else {
slog.Debug("upstream status: notready", "error", err)
proxy.setStatus(notready)
proxy.statusMutex.Lock()
proxy.failCount++
proxy.statusMutex.Unlock()
proxy.incFail(1)
time.Sleep(10 * time.Second)
continue
}
// Successfully connected to SSE endpoint
slog.Info("connected to SSE endpoint, upstream ready")
proxy.setStatus(ready)
proxy.resetFailures()
// Read from the SSE stream to detect disconnection
scanner := bufio.NewScanner(resp.Body)
// use a fairly large buffer to avoid scanner errors when reading large SSE events
buf := make([]byte, 0, 1024*1024*2)
scanner.Buffer(buf, 1024*1024*2)
events := 0
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
fmt.Print("Events: ")
}
for scanner.Scan() {
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
// Just read the events to keep connection alive
// We don't need to process the event data
events++
fmt.Printf("%d, ", events)
}
}
fmt.Println()
if err := scanner.Err(); err != nil {
slog.Error("error reading from SSE stream", "error", err)
}
// Connection closed or error occurred
_ = resp.Body.Close()
slog.Info("SSE connection closed, upstream not ready")
proxy.setStatus(notready)
proxy.incFail(1)
// Wait before reconnecting
time.Sleep(waitDuration)
}
}()
@@ -163,10 +209,8 @@ func newProxy(url *url.URL) *proxyServer {
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && r.URL.Path == "/status" {
p.statusMutex.RLock()
status := string(p.status)
failCount := p.failCount
p.statusMutex.RUnlock()
status := string(p.getStatus())
failCount := p.getFailures()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(200)
fmt.Fprintf(w, "status: %s\n", status)
@@ -175,7 +219,14 @@ func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
if p.getStatus() == notready {
slog.Info("upstream not ready, sending magic packet", "mac", *flagMac)
path := r.URL.Path
if strings.HasPrefix(path, "/api/events") {
slog.Debug("Skipping wake up", "req", path)
w.WriteHeader(http.StatusNoContent)
return
}
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
if err := sendMagicPacket(*flagMac); err != nil {
slog.Warn("failed to send magic WoL packet", "error", err)
}
@@ -213,6 +264,24 @@ func (p *proxyServer) setStatus(status upstreamStatus) {
p.status = status
}
func (p *proxyServer) incFail(num int) {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.failCount += num
}
func (p *proxyServer) getFailures() int {
p.statusMutex.RLock()
defer p.statusMutex.RUnlock()
return p.failCount
}
func (p *proxyServer) resetFailures() {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.failCount = 0
}
func sendMagicPacket(macAddr string) error {
hwAddr, err := net.ParseMAC(macAddr)
if err != nil {
-184
View File
@@ -1,184 +0,0 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
// isStreaming bool
startTime time.Time
}
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
c.Abort()
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
c.Abort()
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
c.Abort()
return
}
writer := &MetricsResponseWriter{
ResponseWriter: c.Writer,
metricsRecorder: &MetricsRecorder{
metricsMonitor: pm.metricsMonitor,
realModelName: realModelName,
startTime: time.Now(),
},
}
c.Writer = writer
c.Next()
// check for streaming response
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
writer.metricsRecorder.processStreamingResponse(writer.body)
} else {
writer.metricsRecorder.processNonStreamingResponse(writer.body)
}
}
}
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
usage := jsonData.Get("usage")
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return false
}
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(rec.startTime).Milliseconds())
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
rec.metricsMonitor.addMetrics(TokenMetrics{
Timestamp: time.Now(),
Model: rec.realModelName,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
})
return true
}
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
// Iterate **backwards** through the lines looking for the data payload with
// usage data
lines := bytes.Split(body, []byte("\n"))
for i := len(lines) - 1; i >= 0; i-- {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
return // short circuit if a metric was recorded
}
}
}
}
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
if len(body) == 0 {
return
}
// Parse JSON to extract usage information
if gjson.ValidBytes(body) {
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
}
}
// MetricsResponseWriter captures the entire response for non-streaming
type MetricsResponseWriter struct {
gin.ResponseWriter
body []byte
metricsRecorder *MetricsRecorder
}
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
if err != nil {
return n, err
}
w.body = append(w.body, b...)
return n, nil
}
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *MetricsResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
+198 -15
View File
@@ -1,12 +1,18 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/tidwall/gjson"
)
// TokenMetrics represents parsed token statistics from llama-server logs
@@ -31,21 +37,18 @@ func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
}
// MetricsMonitor parses llama-server output for token statistics
type MetricsMonitor struct {
// metricsMonitor parses llama-server output for token statistics
type metricsMonitor struct {
mu sync.RWMutex
metrics []TokenMetrics
maxMetrics int
nextID int
logger *LogMonitor
}
func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
maxMetrics := config.MetricsMaxInMemory
if maxMetrics <= 0 {
maxMetrics = 1000 // Default fallback
}
mp := &MetricsMonitor{
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
mp := &metricsMonitor{
logger: logger,
maxMetrics: maxMetrics,
}
@@ -53,7 +56,7 @@ func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
}
// addMetrics adds a new metric to the collection and publishes an event
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
mp.mu.Lock()
defer mp.mu.Unlock()
@@ -66,8 +69,8 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
event.Emit(TokenMetricsEvent{Metrics: metric})
}
// GetMetrics returns a copy of the current metrics
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
// getMetrics returns a copy of the current metrics
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
mp.mu.RLock()
defer mp.mu.RUnlock()
@@ -76,9 +79,189 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
return result
}
// GetMetricsJSON returns metrics as JSON
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
// getMetricsJSON returns metrics as JSON
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
}
// wrapHandler wraps the proxy handler to extract token metrics
// if wrapHandler returns an error it is safe to assume that no
// data was sent to the client
func (mp *metricsMonitor) wrapHandler(
modelID string,
writer gin.ResponseWriter,
request *http.Request,
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error {
recorder := newBodyCopier(writer)
if err := next(modelID, recorder, request); err != nil {
return err
}
// after this point we have to assume that data was sent to the client
// and we can only log errors but not send them to clients
if recorder.Status() != http.StatusOK {
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
return nil
}
body := recorder.body.Bytes()
if len(body) == 0 {
mp.logger.Warn("metrics skipped, empty body")
return nil
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
} else {
mp.addMetrics(tm)
}
} else {
if gjson.ValidBytes(body) {
if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
} else {
mp.addMetrics(tm)
}
} else {
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
}
}
return nil
}
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
// Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split.
// Start from the end of the body and scan backwards for newlines
pos := len(body)
for pos > 0 {
// Find the previous newline (or start of body)
lineStart := bytes.LastIndexByte(body[:pos], '\n')
if lineStart == -1 {
lineStart = 0
} else {
lineStart++ // Move past the newline
}
line := bytes.TrimSpace(body[lineStart:pos])
pos = lineStart - 1 // Move position before the newline for next iteration
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
return parseMetrics(modelID, start, gjson.ParseBytes(data))
}
}
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
}
func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) {
usage := jsonData.Get("usage")
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return TokenMetrics{}, fmt.Errorf("no usage or timings data found")
}
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(start).Milliseconds())
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
return TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
}, nil
}
// responseBodyCopier records the response body and writes to the original response writer
// while also capturing it in a buffer for later processing
type responseBodyCopier struct {
gin.ResponseWriter
body *bytes.Buffer
tee io.Writer
start time.Time
}
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
bodyBuffer := &bytes.Buffer{}
return &responseBodyCopier{
ResponseWriter: w,
body: bodyBuffer,
tee: io.MultiWriter(w, bodyBuffer),
}
}
func (w *responseBodyCopier) Write(b []byte) (int, error) {
if w.start.IsZero() {
w.start = time.Now()
}
// Single write operation that writes to both the response and buffer
return w.tee.Write(b)
}
func (w *responseBodyCopier) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *responseBodyCopier) Header() http.Header {
return w.ResponseWriter.Header()
}
func (w *responseBodyCopier) StartTime() time.Time {
return w.start
}
+693
View File
@@ -0,0 +1,693 @@
package proxy
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/stretchr/testify/assert"
)
func TestMetricsMonitor_AddMetrics(t *testing.T) {
t.Run("adds metrics and assigns ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
}
mm.addMetrics(metric)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 0, metrics[0].ID)
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("increments ID for each metric", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{Model: "model"})
}
metrics := mm.getMetrics()
assert.Equal(t, 5, len(metrics))
for i := 0; i < 5; i++ {
assert.Equal(t, i, metrics[i].ID)
}
})
t.Run("respects max metrics limit", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 3)
// Add 5 metrics
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{
Model: "model",
InputTokens: i,
})
}
metrics := mm.getMetrics()
assert.Equal(t, 3, len(metrics))
// Should keep the last 3 metrics (IDs 2, 3, 4)
assert.Equal(t, 2, metrics[0].ID)
assert.Equal(t, 3, metrics[1].ID)
assert.Equal(t, 4, metrics[2].ID)
})
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
receivedEvent := make(chan TokenMetricsEvent, 1)
cancel := event.On(func(e TokenMetricsEvent) {
receivedEvent <- e
})
defer cancel()
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
}
mm.addMetrics(metric)
select {
case evt := <-receivedEvent:
assert.Equal(t, 0, evt.Metrics.ID)
assert.Equal(t, "test-model", evt.Metrics.Model)
assert.Equal(t, 100, evt.Metrics.InputTokens)
assert.Equal(t, 50, evt.Metrics.OutputTokens)
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for event")
}
})
}
func TestMetricsMonitor_GetMetrics(t *testing.T) {
t.Run("returns empty slice when no metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
metrics := mm.getMetrics()
assert.NotNil(t, metrics)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns copy of metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm.addMetrics(TokenMetrics{Model: "model1"})
mm.addMetrics(TokenMetrics{Model: "model2"})
metrics1 := mm.getMetrics()
metrics2 := mm.getMetrics()
// Verify we got copies
assert.Equal(t, 2, len(metrics1))
assert.Equal(t, 2, len(metrics2))
// Modify the returned slice shouldn't affect the original
metrics1[0].Model = "modified"
metrics3 := mm.getMetrics()
assert.Equal(t, "model1", metrics3[0].Model)
})
}
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
assert.NotNil(t, jsonData)
var metrics []TokenMetrics
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns valid JSON with metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm.addMetrics(TokenMetrics{
Model: "model1",
InputTokens: 100,
OutputTokens: 50,
TokensPerSecond: 25.5,
})
mm.addMetrics(TokenMetrics{
Model: "model2",
InputTokens: 200,
OutputTokens: 100,
TokensPerSecond: 30.0,
})
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
var metrics []TokenMetrics
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 2, len(metrics))
assert.Equal(t, "model1", metrics[0].Model)
assert.Equal(t, "model2", metrics[1].Model)
})
}
func TestMetricsMonitor_WrapHandler(t *testing.T) {
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("successful request with timings data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0,
"cache_n": 20
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 20, metrics[0].CachedTokens)
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
})
t.Run("streaming request with SSE format", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Note: SSE format requires proper line breaks - each data line followed by blank line
responseBody := `data: {"choices":[{"text":"Hello"}]}
data: {"choices":[{"text":" World"}]}
data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}}
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
// When timings data is present, it takes precedence
assert.Equal(t, 10, metrics[0].InputTokens)
assert.Equal(t, 20, metrics[0].OutputTokens)
})
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error"))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("empty response body does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusOK)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("not valid json"))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("next handler error is propagated", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
expectedErr := assert.AnError
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
return expectedErr
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.Equal(t, expectedErr, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"result": "ok"}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
}
func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
t.Run("captures response body", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
testData := []byte("test response body")
n, err := copier.Write(testData)
assert.NoError(t, err)
assert.Equal(t, len(testData), n)
assert.Equal(t, testData, copier.body.Bytes())
assert.Equal(t, string(testData), rec.Body.String())
})
t.Run("sets start time on first write", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
assert.True(t, copier.StartTime().IsZero())
copier.Write([]byte("test"))
assert.False(t, copier.StartTime().IsZero())
})
t.Run("preserves headers", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
copier.Header().Set("X-Test", "value")
assert.Equal(t, "value", rec.Header().Get("X-Test"))
})
t.Run("preserves status code", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
copier.WriteHeader(http.StatusCreated)
// Gin's ResponseWriter tracks status internally
assert.Equal(t, http.StatusCreated, copier.Status())
})
}
func TestMetricsMonitor_Concurrent(t *testing.T) {
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 1000)
var wg sync.WaitGroup
numGoroutines := 10
metricsPerGoroutine := 100
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < metricsPerGoroutine; j++ {
mm.addMetrics(TokenMetrics{
Model: "test-model",
InputTokens: id*1000 + j,
OutputTokens: j,
})
}
}(i)
}
wg.Wait()
metrics := mm.getMetrics()
assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics))
})
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 100)
done := make(chan bool)
// Writer goroutine
go func() {
for i := 0; i < 50; i++ {
mm.addMetrics(TokenMetrics{Model: "test-model"})
time.Sleep(1 * time.Millisecond)
}
done <- true
}()
// Multiple reader goroutines
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 20; j++ {
_ = mm.getMetrics()
_, _ = mm.getMetricsJSON()
time.Sleep(2 * time.Millisecond)
}
}()
}
<-done
wg.Wait()
// Final check
metrics := mm.getMetrics()
assert.Equal(t, 50, len(metrics))
})
}
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
t.Run("prefers timings over usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Timings should take precedence over usage
responseBody := `{
"usage": {
"prompt_tokens": 50,
"completion_tokens": 25
},
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
// Should use timings values, not usage values
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles missing cache_n in timings", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
})
}
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Metrics should be found in the last data line before [DONE]
responseBody := `data: {"choices":[{"text":"First"}]}
data: {"choices":[{"text":"Second"}]}
data: {"usage":{"prompt_tokens":100,"completion_tokens":50}}
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `data: not json
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("handles empty streaming response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := ``
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
// Empty body should not trigger WrapHandler processing
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
}
// Benchmark tests
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
mm := newMetricsMonitor(testLogger, 1000)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
}
}
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
// Test performance with a smaller buffer where wrapping occurs more frequently
mm := newMetricsMonitor(testLogger, 100)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
}
}
+12
View File
@@ -499,6 +499,18 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
startDuration = time.Since(beginStartTime)
}
// recover from http.ErrAbortHandler panics that can occur when the client
// disconnects before the response is sent
defer func() {
if r := recover(); r != nil {
if r == http.ErrAbortHandler {
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
} else {
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
}
}
}()
if p.reverseProxy != nil {
p.reverseProxy.ServeHTTP(w, r)
} else {
+71
View File
@@ -494,3 +494,74 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
}
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
// handled appropriately.
//
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
// can't copy the body. This can be caused by a client disconnecting before the full
// response is sent from some reason.
//
// bug: https://github.com/mostlygeek/llama-swap/issues/362
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
defer func() {
if r := recover(); r != nil {
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
}
}()
expectedMessage := "panic_test"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
defer process.Stop()
// Start the process
err := process.start()
assert.Nil(t, err)
assert.Equal(t, StateReady, process.CurrentState())
// Create a custom ResponseWriter that simulates a client disconnect
// by panicking when Write is called after headers are sent
panicWriter := &panicOnWriteResponseWriter{
ResponseRecorder: httptest.NewRecorder(),
shouldPanic: true,
}
// Make a request that will trigger the panic
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
// ProxyRequest should catch and handle this panic gracefully.
process.ProxyRequest(panicWriter, req)
// If we get here, the panic was properly recovered in ProxyRequest
// The process should still be in a ready state
assert.Equal(t, StateReady, process.CurrentState())
}
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
// to simulate a client disconnect after headers are sent
// used by: TestProcess_ReverseProxyPanicIsHandled
type panicOnWriteResponseWriter struct {
*httptest.ResponseRecorder
shouldPanic bool
headerWritten bool
}
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
w.headerWritten = true
w.ResponseRecorder.WriteHeader(statusCode)
}
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
if w.shouldPanic && w.headerWritten {
// Simulate the panic that httputil.ReverseProxy throws
panic(http.ErrAbortHandler)
}
return w.ResponseRecorder.Write(b)
}
+46 -18
View File
@@ -36,7 +36,7 @@ type ProxyManager struct {
upstreamLogger *LogMonitor
muxLogger *LogMonitor
metricsMonitor *MetricsMonitor
metricsMonitor *metricsMonitor
processGroups map[string]*ProcessGroup
@@ -75,6 +75,13 @@ func New(config config.Config) *ProxyManager {
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
var maxMetrics int
if config.MetricsMaxInMemory <= 0 {
maxMetrics = 1000 // Default fallback
} else {
maxMetrics = config.MetricsMaxInMemory
}
pm := &ProxyManager{
config: config,
ginEngine: gin.New(),
@@ -83,7 +90,7 @@ func New(config config.Config) *ProxyManager {
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
metricsMonitor: NewMetricsMonitor(&config),
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
processGroups: make(map[string]*ProcessGroup),
@@ -193,27 +200,25 @@ func (pm *ProxyManager) setupGinEngine() {
c.Next()
})
mm := MetricsMiddleware(pm)
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
// Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
// llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/reranking", pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/reranking", pm.proxyOAIHandler)
// llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/infill", pm.proxyOAIHandler)
// llama-server's /completion endpoint
pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/completion", pm.proxyOAIHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
@@ -474,8 +479,23 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
}
// rewrite the path
originalPath := c.Request.URL.Path
c.Request.URL.Path = remainingPath
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
return
}
} else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
return
}
}
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
@@ -535,10 +555,18 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
}
}
+2 -2
View File
@@ -180,7 +180,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
sendLogData("proxy", pm.proxyLogger.GetHistory())
sendLogData("upstream", pm.upstreamLogger.GetHistory())
sendModels()
sendMetrics(pm.metricsMonitor.GetMetrics())
sendMetrics(pm.metricsMonitor.getMetrics())
for {
select {
@@ -198,7 +198,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
}
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
jsonData, err := pm.metricsMonitor.getMetricsJSON()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
return
-70
View File
@@ -911,76 +911,6 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
// t.Logf("%v", response)
}
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a non-streaming request
reqBody := `{"model":"model1", "stream": false}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a streaming request
reqBody := `{"model":"model1", "stream": true}`
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_HealthEndpoint(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,