Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f852689104 | |||
| e250e71e59 | |||
| d18dc26d01 | |||
| 8357714421 |
@@ -90,7 +90,7 @@ GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
|
|||||||
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
|
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
|
||||||
wol-proxy: $(BUILD_DIR)
|
wol-proxy: $(BUILD_DIR)
|
||||||
@echo "Building wol-proxy"
|
@echo "Building wol-proxy"
|
||||||
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH) cmd/wol-proxy/wol-proxy.go
|
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
|
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
|
||||||
|
|||||||
+104
-35
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -92,14 +94,9 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// graceful shutdown
|
// graceful shutdown
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
|
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||||
defer stop()
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
server.Close()
|
||||||
defer cancel()
|
|
||||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
|
||||||
slog.Error("server shutdown error", "error", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type upstreamStatus string
|
type upstreamStatus string
|
||||||
@@ -124,37 +121,86 @@ func newProxy(url *url.URL) *proxyServer {
|
|||||||
failCount: 0,
|
failCount: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
// start a goroutien to check upstream status
|
// start a goroutine to monitor upstream status via SSE
|
||||||
go func() {
|
go func() {
|
||||||
checkUrl := url.Scheme + "://" + url.Host + "/wol-health"
|
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
|
||||||
client := &http.Client{Timeout: time.Second}
|
client := &http.Client{
|
||||||
ticker := time.NewTicker(2 * time.Second)
|
Timeout: 0, // No timeout for SSE connection
|
||||||
defer ticker.Stop()
|
}
|
||||||
for range ticker.C {
|
|
||||||
|
|
||||||
slog.Debug("checking upstream status at", "url", checkUrl)
|
waitDuration := 10 * time.Second
|
||||||
resp, err := client.Get(checkUrl)
|
|
||||||
|
|
||||||
// drain the body
|
for {
|
||||||
if err == nil && resp != nil {
|
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", eventsUrl, nil)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to create SSE request", "error", err)
|
||||||
|
proxy.setStatus(notready)
|
||||||
|
proxy.incFail(1)
|
||||||
|
time.Sleep(waitDuration)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
req.Header.Set("Cache-Control", "no-cache")
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to connect to SSE endpoint", "error", err)
|
||||||
|
proxy.setStatus(notready)
|
||||||
|
proxy.incFail(1)
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
|
||||||
_, _ = io.Copy(io.Discard, resp.Body)
|
_, _ = io.Copy(io.Discard, resp.Body)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil && resp != nil && resp.StatusCode == http.StatusOK {
|
|
||||||
slog.Debug("upstream status: ready")
|
|
||||||
proxy.setStatus(ready)
|
|
||||||
proxy.statusMutex.Lock()
|
|
||||||
proxy.failCount = 0
|
|
||||||
proxy.statusMutex.Unlock()
|
|
||||||
} else {
|
|
||||||
slog.Debug("upstream status: notready", "error", err)
|
|
||||||
proxy.setStatus(notready)
|
proxy.setStatus(notready)
|
||||||
proxy.statusMutex.Lock()
|
proxy.incFail(1)
|
||||||
proxy.failCount++
|
time.Sleep(10 * time.Second)
|
||||||
proxy.statusMutex.Unlock()
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Successfully connected to SSE endpoint
|
||||||
|
slog.Info("connected to SSE endpoint, upstream ready")
|
||||||
|
proxy.setStatus(ready)
|
||||||
|
proxy.resetFailures()
|
||||||
|
|
||||||
|
// Read from the SSE stream to detect disconnection
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
|
||||||
|
// use a fairly large buffer to avoid scanner errors when reading large SSE events
|
||||||
|
buf := make([]byte, 0, 1024*1024*2)
|
||||||
|
scanner.Buffer(buf, 1024*1024*2)
|
||||||
|
events := 0
|
||||||
|
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
fmt.Print("Events: ")
|
||||||
|
}
|
||||||
|
for scanner.Scan() {
|
||||||
|
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
// Just read the events to keep connection alive
|
||||||
|
// We don't need to process the event data
|
||||||
|
events++
|
||||||
|
fmt.Printf("%d, ", events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
slog.Error("error reading from SSE stream", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection closed or error occurred
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
slog.Info("SSE connection closed, upstream not ready")
|
||||||
|
proxy.setStatus(notready)
|
||||||
|
proxy.incFail(1)
|
||||||
|
|
||||||
|
// Wait before reconnecting
|
||||||
|
time.Sleep(waitDuration)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -163,10 +209,8 @@ func newProxy(url *url.URL) *proxyServer {
|
|||||||
|
|
||||||
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "GET" && r.URL.Path == "/status" {
|
if r.Method == "GET" && r.URL.Path == "/status" {
|
||||||
p.statusMutex.RLock()
|
status := string(p.getStatus())
|
||||||
status := string(p.status)
|
failCount := p.getFailures()
|
||||||
failCount := p.failCount
|
|
||||||
p.statusMutex.RUnlock()
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
fmt.Fprintf(w, "status: %s\n", status)
|
fmt.Fprintf(w, "status: %s\n", status)
|
||||||
@@ -175,7 +219,14 @@ func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if p.getStatus() == notready {
|
if p.getStatus() == notready {
|
||||||
slog.Info("upstream not ready, sending magic packet", "mac", *flagMac)
|
path := r.URL.Path
|
||||||
|
if strings.HasPrefix(path, "/api/events") {
|
||||||
|
slog.Debug("Skipping wake up", "req", path)
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
|
||||||
if err := sendMagicPacket(*flagMac); err != nil {
|
if err := sendMagicPacket(*flagMac); err != nil {
|
||||||
slog.Warn("failed to send magic WoL packet", "error", err)
|
slog.Warn("failed to send magic WoL packet", "error", err)
|
||||||
}
|
}
|
||||||
@@ -213,6 +264,24 @@ func (p *proxyServer) setStatus(status upstreamStatus) {
|
|||||||
p.status = status
|
p.status = status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) incFail(num int) {
|
||||||
|
p.statusMutex.Lock()
|
||||||
|
defer p.statusMutex.Unlock()
|
||||||
|
p.failCount += num
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) getFailures() int {
|
||||||
|
p.statusMutex.RLock()
|
||||||
|
defer p.statusMutex.RUnlock()
|
||||||
|
return p.failCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) resetFailures() {
|
||||||
|
p.statusMutex.Lock()
|
||||||
|
defer p.statusMutex.Unlock()
|
||||||
|
p.failCount = 0
|
||||||
|
}
|
||||||
|
|
||||||
func sendMagicPacket(macAddr string) error {
|
func sendMagicPacket(macAddr string) error {
|
||||||
hwAddr, err := net.ParseMAC(macAddr)
|
hwAddr, err := net.ParseMAC(macAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,184 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MetricsRecorder struct {
|
|
||||||
metricsMonitor *MetricsMonitor
|
|
||||||
realModelName string
|
|
||||||
// isStreaming bool
|
|
||||||
startTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
|
||||||
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
|
|
||||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
|
||||||
if requestedModel == "" {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
|
||||||
if !found {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := &MetricsResponseWriter{
|
|
||||||
ResponseWriter: c.Writer,
|
|
||||||
metricsRecorder: &MetricsRecorder{
|
|
||||||
metricsMonitor: pm.metricsMonitor,
|
|
||||||
realModelName: realModelName,
|
|
||||||
startTime: time.Now(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
c.Writer = writer
|
|
||||||
c.Next()
|
|
||||||
|
|
||||||
// check for streaming response
|
|
||||||
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
|
|
||||||
writer.metricsRecorder.processStreamingResponse(writer.body)
|
|
||||||
} else {
|
|
||||||
writer.metricsRecorder.processNonStreamingResponse(writer.body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
|
|
||||||
usage := jsonData.Get("usage")
|
|
||||||
timings := jsonData.Get("timings")
|
|
||||||
if !usage.Exists() && !timings.Exists() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// default values
|
|
||||||
cachedTokens := -1 // unknown or missing data
|
|
||||||
outputTokens := 0
|
|
||||||
inputTokens := 0
|
|
||||||
|
|
||||||
// timings data
|
|
||||||
tokensPerSecond := -1.0
|
|
||||||
promptPerSecond := -1.0
|
|
||||||
durationMs := int(time.Since(rec.startTime).Milliseconds())
|
|
||||||
|
|
||||||
if usage.Exists() {
|
|
||||||
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
|
||||||
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
|
||||||
}
|
|
||||||
|
|
||||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
|
||||||
if timings.Exists() {
|
|
||||||
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
|
||||||
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
|
||||||
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
|
||||||
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
|
||||||
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
|
||||||
|
|
||||||
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
|
|
||||||
cachedTokens = int(cachedValue.Int())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rec.metricsMonitor.addMetrics(TokenMetrics{
|
|
||||||
Timestamp: time.Now(),
|
|
||||||
Model: rec.realModelName,
|
|
||||||
CachedTokens: cachedTokens,
|
|
||||||
InputTokens: inputTokens,
|
|
||||||
OutputTokens: outputTokens,
|
|
||||||
PromptPerSecond: promptPerSecond,
|
|
||||||
TokensPerSecond: tokensPerSecond,
|
|
||||||
DurationMs: durationMs,
|
|
||||||
})
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
|
||||||
// Iterate **backwards** through the lines looking for the data payload with
|
|
||||||
// usage data
|
|
||||||
lines := bytes.Split(body, []byte("\n"))
|
|
||||||
|
|
||||||
for i := len(lines) - 1; i >= 0; i-- {
|
|
||||||
line := bytes.TrimSpace(lines[i])
|
|
||||||
if len(line) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSE payload always follows "data:"
|
|
||||||
prefix := []byte("data:")
|
|
||||||
if !bytes.HasPrefix(line, prefix) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data := bytes.TrimSpace(line[len(prefix):])
|
|
||||||
|
|
||||||
if len(data) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Equal(data, []byte("[DONE]")) {
|
|
||||||
// [DONE] line itself contains nothing of interest.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if gjson.ValidBytes(data) {
|
|
||||||
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
|
|
||||||
return // short circuit if a metric was recorded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
|
||||||
if len(body) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse JSON to extract usage information
|
|
||||||
if gjson.ValidBytes(body) {
|
|
||||||
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MetricsResponseWriter captures the entire response for non-streaming
|
|
||||||
type MetricsResponseWriter struct {
|
|
||||||
gin.ResponseWriter
|
|
||||||
body []byte
|
|
||||||
metricsRecorder *MetricsRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
|
||||||
n, err := w.ResponseWriter.Write(b)
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
w.body = append(w.body, b...)
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
|
||||||
w.ResponseWriter.WriteHeader(statusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *MetricsResponseWriter) Header() http.Header {
|
|
||||||
return w.ResponseWriter.Header()
|
|
||||||
}
|
|
||||||
+198
-15
@@ -1,12 +1,18 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||||
@@ -31,21 +37,18 @@ func (e TokenMetricsEvent) Type() uint32 {
|
|||||||
return TokenMetricsEventID // defined in events.go
|
return TokenMetricsEventID // defined in events.go
|
||||||
}
|
}
|
||||||
|
|
||||||
// MetricsMonitor parses llama-server output for token statistics
|
// metricsMonitor parses llama-server output for token statistics
|
||||||
type MetricsMonitor struct {
|
type metricsMonitor struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
metrics []TokenMetrics
|
metrics []TokenMetrics
|
||||||
maxMetrics int
|
maxMetrics int
|
||||||
nextID int
|
nextID int
|
||||||
|
logger *LogMonitor
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
|
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
|
||||||
maxMetrics := config.MetricsMaxInMemory
|
mp := &metricsMonitor{
|
||||||
if maxMetrics <= 0 {
|
logger: logger,
|
||||||
maxMetrics = 1000 // Default fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
mp := &MetricsMonitor{
|
|
||||||
maxMetrics: maxMetrics,
|
maxMetrics: maxMetrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +56,7 @@ func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addMetrics adds a new metric to the collection and publishes an event
|
// addMetrics adds a new metric to the collection and publishes an event
|
||||||
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||||
mp.mu.Lock()
|
mp.mu.Lock()
|
||||||
defer mp.mu.Unlock()
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
@@ -66,8 +69,8 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
|||||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMetrics returns a copy of the current metrics
|
// getMetrics returns a copy of the current metrics
|
||||||
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
||||||
mp.mu.RLock()
|
mp.mu.RLock()
|
||||||
defer mp.mu.RUnlock()
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
@@ -76,9 +79,189 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMetricsJSON returns metrics as JSON
|
// getMetricsJSON returns metrics as JSON
|
||||||
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
|
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||||
mp.mu.RLock()
|
mp.mu.RLock()
|
||||||
defer mp.mu.RUnlock()
|
defer mp.mu.RUnlock()
|
||||||
return json.Marshal(mp.metrics)
|
return json.Marshal(mp.metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wrapHandler wraps the proxy handler to extract token metrics
|
||||||
|
// if wrapHandler returns an error it is safe to assume that no
|
||||||
|
// data was sent to the client
|
||||||
|
func (mp *metricsMonitor) wrapHandler(
|
||||||
|
modelID string,
|
||||||
|
writer gin.ResponseWriter,
|
||||||
|
request *http.Request,
|
||||||
|
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||||
|
) error {
|
||||||
|
recorder := newBodyCopier(writer)
|
||||||
|
if err := next(modelID, recorder, request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// after this point we have to assume that data was sent to the client
|
||||||
|
// and we can only log errors but not send them to clients
|
||||||
|
|
||||||
|
if recorder.Status() != http.StatusOK {
|
||||||
|
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body := recorder.body.Bytes()
|
||||||
|
if len(body) == 0 {
|
||||||
|
mp.logger.Warn("metrics skipped, empty body")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||||
|
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||||
|
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
||||||
|
} else {
|
||||||
|
mp.addMetrics(tm)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if gjson.ValidBytes(body) {
|
||||||
|
if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil {
|
||||||
|
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
||||||
|
} else {
|
||||||
|
mp.addMetrics(tm)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
|
||||||
|
// Iterate **backwards** through the body looking for the data payload with
|
||||||
|
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
||||||
|
|
||||||
|
// Start from the end of the body and scan backwards for newlines
|
||||||
|
pos := len(body)
|
||||||
|
for pos > 0 {
|
||||||
|
// Find the previous newline (or start of body)
|
||||||
|
lineStart := bytes.LastIndexByte(body[:pos], '\n')
|
||||||
|
if lineStart == -1 {
|
||||||
|
lineStart = 0
|
||||||
|
} else {
|
||||||
|
lineStart++ // Move past the newline
|
||||||
|
}
|
||||||
|
|
||||||
|
line := bytes.TrimSpace(body[lineStart:pos])
|
||||||
|
pos = lineStart - 1 // Move position before the newline for next iteration
|
||||||
|
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE payload always follows "data:"
|
||||||
|
prefix := []byte("data:")
|
||||||
|
if !bytes.HasPrefix(line, prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[len(prefix):])
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
// [DONE] line itself contains nothing of interest.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.ValidBytes(data) {
|
||||||
|
return parseMetrics(modelID, start, gjson.ParseBytes(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) {
|
||||||
|
usage := jsonData.Get("usage")
|
||||||
|
timings := jsonData.Get("timings")
|
||||||
|
if !usage.Exists() && !timings.Exists() {
|
||||||
|
return TokenMetrics{}, fmt.Errorf("no usage or timings data found")
|
||||||
|
}
|
||||||
|
// default values
|
||||||
|
cachedTokens := -1 // unknown or missing data
|
||||||
|
outputTokens := 0
|
||||||
|
inputTokens := 0
|
||||||
|
|
||||||
|
// timings data
|
||||||
|
tokensPerSecond := -1.0
|
||||||
|
promptPerSecond := -1.0
|
||||||
|
durationMs := int(time.Since(start).Milliseconds())
|
||||||
|
|
||||||
|
if usage.Exists() {
|
||||||
|
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
||||||
|
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||||
|
if timings.Exists() {
|
||||||
|
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
||||||
|
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
||||||
|
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
||||||
|
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||||
|
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||||
|
|
||||||
|
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
|
||||||
|
cachedTokens = int(cachedValue.Int())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TokenMetrics{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: modelID,
|
||||||
|
CachedTokens: cachedTokens,
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
PromptPerSecond: promptPerSecond,
|
||||||
|
TokensPerSecond: tokensPerSecond,
|
||||||
|
DurationMs: durationMs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseBodyCopier records the response body and writes to the original response writer
|
||||||
|
// while also capturing it in a buffer for later processing
|
||||||
|
type responseBodyCopier struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
body *bytes.Buffer
|
||||||
|
tee io.Writer
|
||||||
|
start time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||||
|
bodyBuffer := &bytes.Buffer{}
|
||||||
|
return &responseBodyCopier{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: bodyBuffer,
|
||||||
|
tee: io.MultiWriter(w, bodyBuffer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||||
|
if w.start.IsZero() {
|
||||||
|
w.start = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single write operation that writes to both the response and buffer
|
||||||
|
return w.tee.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||||
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) Header() http.Header {
|
||||||
|
return w.ResponseWriter.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) StartTime() time.Time {
|
||||||
|
return w.start
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,693 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||||
|
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
metric := TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
mm.addMetrics(metric)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, 0, metrics[0].ID)
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
mm.addMetrics(TokenMetrics{Model: "model"})
|
||||||
|
}
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 5, len(metrics))
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
assert.Equal(t, i, metrics[i].ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("respects max metrics limit", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 3)
|
||||||
|
|
||||||
|
// Add 5 metrics
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
mm.addMetrics(TokenMetrics{
|
||||||
|
Model: "model",
|
||||||
|
InputTokens: i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 3, len(metrics))
|
||||||
|
|
||||||
|
// Should keep the last 3 metrics (IDs 2, 3, 4)
|
||||||
|
assert.Equal(t, 2, metrics[0].ID)
|
||||||
|
assert.Equal(t, 3, metrics[1].ID)
|
||||||
|
assert.Equal(t, 4, metrics[2].ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
receivedEvent := make(chan TokenMetricsEvent, 1)
|
||||||
|
cancel := event.On(func(e TokenMetricsEvent) {
|
||||||
|
receivedEvent <- e
|
||||||
|
})
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
metric := TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
mm.addMetrics(metric)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case evt := <-receivedEvent:
|
||||||
|
assert.Equal(t, 0, evt.Metrics.ID)
|
||||||
|
assert.Equal(t, "test-model", evt.Metrics.Model)
|
||||||
|
assert.Equal(t, 100, evt.Metrics.InputTokens)
|
||||||
|
assert.Equal(t, 50, evt.Metrics.OutputTokens)
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for event")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||||
|
t.Run("returns empty slice when no metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.NotNil(t, metrics)
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
mm.addMetrics(TokenMetrics{Model: "model1"})
|
||||||
|
mm.addMetrics(TokenMetrics{Model: "model2"})
|
||||||
|
|
||||||
|
metrics1 := mm.getMetrics()
|
||||||
|
metrics2 := mm.getMetrics()
|
||||||
|
|
||||||
|
// Verify we got copies
|
||||||
|
assert.Equal(t, 2, len(metrics1))
|
||||||
|
assert.Equal(t, 2, len(metrics2))
|
||||||
|
|
||||||
|
// Modify the returned slice shouldn't affect the original
|
||||||
|
metrics1[0].Model = "modified"
|
||||||
|
metrics3 := mm.getMetrics()
|
||||||
|
assert.Equal(t, "model1", metrics3[0].Model)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||||
|
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
jsonData, err := mm.getMetricsJSON()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, jsonData)
|
||||||
|
|
||||||
|
var metrics []TokenMetrics
|
||||||
|
err = json.Unmarshal(jsonData, &metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
mm.addMetrics(TokenMetrics{
|
||||||
|
Model: "model1",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
TokensPerSecond: 25.5,
|
||||||
|
})
|
||||||
|
mm.addMetrics(TokenMetrics{
|
||||||
|
Model: "model2",
|
||||||
|
InputTokens: 200,
|
||||||
|
OutputTokens: 100,
|
||||||
|
TokensPerSecond: 30.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
jsonData, err := mm.getMetricsJSON()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var metrics []TokenMetrics
|
||||||
|
err = json.Unmarshal(jsonData, &metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, len(metrics))
|
||||||
|
assert.Equal(t, "model1", metrics[0].Model)
|
||||||
|
assert.Equal(t, "model2", metrics[1].Model)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||||
|
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 100,
|
||||||
|
"completion_tokens": 50
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("successful request with timings data", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"timings": {
|
||||||
|
"prompt_n": 100,
|
||||||
|
"predicted_n": 50,
|
||||||
|
"prompt_per_second": 150.5,
|
||||||
|
"predicted_per_second": 25.5,
|
||||||
|
"prompt_ms": 500.0,
|
||||||
|
"predicted_ms": 1500.0,
|
||||||
|
"cache_n": 20
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
assert.Equal(t, 20, metrics[0].CachedTokens)
|
||||||
|
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
|
||||||
|
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
|
||||||
|
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("streaming request with SSE format", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
// Note: SSE format requires proper line breaks - each data line followed by blank line
|
||||||
|
responseBody := `data: {"choices":[{"text":"Hello"}]}
|
||||||
|
|
||||||
|
data: {"choices":[{"text":" World"}]}
|
||||||
|
|
||||||
|
data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
// When timings data is present, it takes precedence
|
||||||
|
assert.Equal(t, 10, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 20, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte("error"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty response body does not record metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("not valid json"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
expectedErr := assert.AnError
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
return expectedErr
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.Equal(t, expectedErr, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `{"result": "ok"}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
||||||
|
t.Run("captures response body", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
|
testData := []byte("test response body")
|
||||||
|
n, err := copier.Write(testData)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, len(testData), n)
|
||||||
|
assert.Equal(t, testData, copier.body.Bytes())
|
||||||
|
assert.Equal(t, string(testData), rec.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sets start time on first write", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
|
assert.True(t, copier.StartTime().IsZero())
|
||||||
|
|
||||||
|
copier.Write([]byte("test"))
|
||||||
|
|
||||||
|
assert.False(t, copier.StartTime().IsZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves headers", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
|
copier.Header().Set("X-Test", "value")
|
||||||
|
|
||||||
|
assert.Equal(t, "value", rec.Header().Get("X-Test"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves status code", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
|
copier.WriteHeader(http.StatusCreated)
|
||||||
|
|
||||||
|
// Gin's ResponseWriter tracks status internally
|
||||||
|
assert.Equal(t, http.StatusCreated, copier.Status())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||||
|
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 1000)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
numGoroutines := 10
|
||||||
|
metricsPerGoroutine := 100
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < metricsPerGoroutine; j++ {
|
||||||
|
mm.addMetrics(TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
InputTokens: id*1000 + j,
|
||||||
|
OutputTokens: j,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 100)
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
|
||||||
|
// Writer goroutine
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
mm.addMetrics(TokenMetrics{Model: "test-model"})
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Multiple reader goroutines
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 20; j++ {
|
||||||
|
_ = mm.getMetrics()
|
||||||
|
_, _ = mm.getMetricsJSON()
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
<-done
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Final check
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 50, len(metrics))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||||
|
t.Run("prefers timings over usage data", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
// Timings should take precedence over usage
|
||||||
|
responseBody := `{
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 50,
|
||||||
|
"completion_tokens": 25
|
||||||
|
},
|
||||||
|
"timings": {
|
||||||
|
"prompt_n": 100,
|
||||||
|
"predicted_n": 50,
|
||||||
|
"prompt_per_second": 150.5,
|
||||||
|
"predicted_per_second": 25.5,
|
||||||
|
"prompt_ms": 500.0,
|
||||||
|
"predicted_ms": 1500.0
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
// Should use timings values, not usage values
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"timings": {
|
||||||
|
"prompt_n": 100,
|
||||||
|
"predicted_n": 50,
|
||||||
|
"prompt_per_second": 150.5,
|
||||||
|
"predicted_per_second": 25.5,
|
||||||
|
"prompt_ms": 500.0,
|
||||||
|
"predicted_ms": 1500.0
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
|
||||||
|
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
// Metrics should be found in the last data line before [DONE]
|
||||||
|
responseBody := `data: {"choices":[{"text":"First"}]}
|
||||||
|
|
||||||
|
data: {"choices":[{"text":"Second"}]}
|
||||||
|
|
||||||
|
data: {"usage":{"prompt_tokens":100,"completion_tokens":50}}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `data: not json
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles empty streaming response", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := ``
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
// Empty body should not trigger WrapHandler processing
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 1000)
|
||||||
|
|
||||||
|
metric := TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
CachedTokens: 100,
|
||||||
|
InputTokens: 500,
|
||||||
|
OutputTokens: 250,
|
||||||
|
PromptPerSecond: 1200.5,
|
||||||
|
TokensPerSecond: 45.8,
|
||||||
|
DurationMs: 5000,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mm.addMetrics(metric)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||||
|
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||||
|
mm := newMetricsMonitor(testLogger, 100)
|
||||||
|
|
||||||
|
metric := TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
CachedTokens: 100,
|
||||||
|
InputTokens: 500,
|
||||||
|
OutputTokens: 250,
|
||||||
|
PromptPerSecond: 1200.5,
|
||||||
|
TokensPerSecond: 45.8,
|
||||||
|
DurationMs: 5000,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mm.addMetrics(metric)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -499,6 +499,18 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
startDuration = time.Since(beginStartTime)
|
startDuration = time.Since(beginStartTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// recover from http.ErrAbortHandler panics that can occur when the client
|
||||||
|
// disconnects before the response is sent
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if r == http.ErrAbortHandler {
|
||||||
|
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if p.reverseProxy != nil {
|
if p.reverseProxy != nil {
|
||||||
p.reverseProxy.ServeHTTP(w, r)
|
p.reverseProxy.ServeHTTP(w, r)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -494,3 +494,74 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
|||||||
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
|
||||||
|
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
|
||||||
|
// handled appropriately.
|
||||||
|
//
|
||||||
|
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
|
||||||
|
// can't copy the body. This can be caused by a client disconnecting before the full
|
||||||
|
// response is sent from some reason.
|
||||||
|
//
|
||||||
|
// bug: https://github.com/mostlygeek/llama-swap/issues/362
|
||||||
|
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
|
||||||
|
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
|
||||||
|
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
|
||||||
|
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
expectedMessage := "panic_test"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
// Start the process
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
|
||||||
|
// Create a custom ResponseWriter that simulates a client disconnect
|
||||||
|
// by panicking when Write is called after headers are sent
|
||||||
|
panicWriter := &panicOnWriteResponseWriter{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
shouldPanic: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a request that will trigger the panic
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
|
||||||
|
|
||||||
|
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
|
||||||
|
// ProxyRequest should catch and handle this panic gracefully.
|
||||||
|
process.ProxyRequest(panicWriter, req)
|
||||||
|
|
||||||
|
// If we get here, the panic was properly recovered in ProxyRequest
|
||||||
|
// The process should still be in a ready state
|
||||||
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
}
|
||||||
|
|
||||||
|
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
|
||||||
|
// to simulate a client disconnect after headers are sent
|
||||||
|
// used by: TestProcess_ReverseProxyPanicIsHandled
|
||||||
|
type panicOnWriteResponseWriter struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
shouldPanic bool
|
||||||
|
headerWritten bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
w.headerWritten = true
|
||||||
|
w.ResponseRecorder.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
if w.shouldPanic && w.headerWritten {
|
||||||
|
// Simulate the panic that httputil.ReverseProxy throws
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
}
|
||||||
|
return w.ResponseRecorder.Write(b)
|
||||||
|
}
|
||||||
|
|||||||
+42
-14
@@ -36,7 +36,7 @@ type ProxyManager struct {
|
|||||||
upstreamLogger *LogMonitor
|
upstreamLogger *LogMonitor
|
||||||
muxLogger *LogMonitor
|
muxLogger *LogMonitor
|
||||||
|
|
||||||
metricsMonitor *MetricsMonitor
|
metricsMonitor *metricsMonitor
|
||||||
|
|
||||||
processGroups map[string]*ProcessGroup
|
processGroups map[string]*ProcessGroup
|
||||||
|
|
||||||
@@ -75,6 +75,13 @@ func New(config config.Config) *ProxyManager {
|
|||||||
|
|
||||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
var maxMetrics int
|
||||||
|
if config.MetricsMaxInMemory <= 0 {
|
||||||
|
maxMetrics = 1000 // Default fallback
|
||||||
|
} else {
|
||||||
|
maxMetrics = config.MetricsMaxInMemory
|
||||||
|
}
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: config,
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
@@ -83,7 +90,7 @@ func New(config config.Config) *ProxyManager {
|
|||||||
muxLogger: stdoutLogger,
|
muxLogger: stdoutLogger,
|
||||||
upstreamLogger: upstreamLogger,
|
upstreamLogger: upstreamLogger,
|
||||||
|
|
||||||
metricsMonitor: NewMetricsMonitor(&config),
|
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
|
||||||
|
|
||||||
processGroups: make(map[string]*ProcessGroup),
|
processGroups: make(map[string]*ProcessGroup),
|
||||||
|
|
||||||
@@ -193,27 +200,25 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
c.Next()
|
c.Next()
|
||||||
})
|
})
|
||||||
|
|
||||||
mm := MetricsMiddleware(pm)
|
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// Support embeddings and reranking
|
// Support embeddings and reranking
|
||||||
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// llama-server's /reranking endpoint + aliases
|
// llama-server's /reranking endpoint + aliases
|
||||||
pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/reranking", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/rerank", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/reranking", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// llama-server's /infill endpoint for code infilling
|
// llama-server's /infill endpoint for code infilling
|
||||||
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/infill", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// llama-server's /completion endpoint
|
// llama-server's /completion endpoint
|
||||||
pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/completion", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
@@ -474,8 +479,23 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
|
originalPath := c.Request.URL.Path
|
||||||
c.Request.URL.Path = remainingPath
|
c.Request.URL.Path = remainingPath
|
||||||
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
|
||||||
|
// attempt to record metrics if it is a POST request
|
||||||
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
|
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||||
@@ -535,11 +555,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
c.Request.ContentLength = int64(len(bodyBytes))
|
c.Request.ContentLength = int64(len(bodyBytes))
|
||||||
|
|
||||||
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
|
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
|||||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||||
sendModels()
|
sendModels()
|
||||||
sendMetrics(pm.metricsMonitor.GetMetrics())
|
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -198,7 +198,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||||
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
|
jsonData, err := pm.metricsMonitor.getMetricsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -911,76 +911,6 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
|||||||
// t.Logf("%v", response)
|
// t.Logf("%v", response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
|
||||||
config := config.AddDefaultGroupToConfig(config.Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]config.ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogLevel: "error",
|
|
||||||
})
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
||||||
|
|
||||||
// Make a non-streaming request
|
|
||||||
reqBody := `{"model":"model1", "stream": false}`
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := CreateTestResponseRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
// Check that metrics were recorded
|
|
||||||
metrics := proxy.metricsMonitor.GetMetrics()
|
|
||||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the last metric has the correct model
|
|
||||||
lastMetric := metrics[len(metrics)-1]
|
|
||||||
assert.Equal(t, "model1", lastMetric.Model)
|
|
||||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
|
||||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
|
||||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
|
||||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
|
||||||
config := config.AddDefaultGroupToConfig(config.Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]config.ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogLevel: "error",
|
|
||||||
})
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
||||||
|
|
||||||
// Make a streaming request
|
|
||||||
reqBody := `{"model":"model1", "stream": true}`
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
|
||||||
w := CreateTestResponseRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
// Check that metrics were recorded
|
|
||||||
metrics := proxy.metricsMonitor.GetMetrics()
|
|
||||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the last metric has the correct model
|
|
||||||
lastMetric := metrics[len(metrics)-1]
|
|
||||||
assert.Equal(t, "model1", lastMetric.Model)
|
|
||||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
|
||||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
|
||||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
|
||||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||||
config := config.AddDefaultGroupToConfig(config.Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
|
|||||||
Generated
+3
-3
@@ -3975,9 +3975,9 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/vite": {
|
"node_modules/vite": {
|
||||||
"version": "6.3.5",
|
"version": "6.4.1",
|
||||||
"resolved": "https://registry.npmjs.org/vite/-/vite-6.3.5.tgz",
|
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
|
||||||
"integrity": "sha512-cZn6NDFE7wdTpINgs++ZJ4N49W2vRp8LCKrn3Ob1kYNtOo21vfDoaV5GzBfLU4MovSAB8uNRm4jgzVQZ+mBzPQ==",
|
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
|||||||
+278
-20
@@ -191,42 +191,300 @@ function ModelsPanel() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface HistogramData {
|
||||||
|
bins: number[];
|
||||||
|
min: number;
|
||||||
|
max: number;
|
||||||
|
binSize: number;
|
||||||
|
p99: number;
|
||||||
|
p95: number;
|
||||||
|
p50: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
function TokenHistogram({ data }: { data: HistogramData }) {
|
||||||
|
const { bins, min, max, p50, p95, p99 } = data;
|
||||||
|
const maxCount = Math.max(...bins);
|
||||||
|
|
||||||
|
const height = 120;
|
||||||
|
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
|
||||||
|
|
||||||
|
// Use viewBox for responsive sizing
|
||||||
|
const viewBoxWidth = 600;
|
||||||
|
const chartWidth = viewBoxWidth - padding.left - padding.right;
|
||||||
|
const chartHeight = height - padding.top - padding.bottom;
|
||||||
|
|
||||||
|
const barWidth = chartWidth / bins.length;
|
||||||
|
const range = max - min;
|
||||||
|
|
||||||
|
// Calculate x position for a given value
|
||||||
|
const getXPosition = (value: number) => {
|
||||||
|
return padding.left + ((value - min) / range) * chartWidth;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="mt-2 w-full">
|
||||||
|
<svg
|
||||||
|
viewBox={`0 0 ${viewBoxWidth} ${height}`}
|
||||||
|
className="w-full h-auto"
|
||||||
|
preserveAspectRatio="xMidYMid meet"
|
||||||
|
>
|
||||||
|
{/* Y-axis */}
|
||||||
|
<line
|
||||||
|
x1={padding.left}
|
||||||
|
y1={padding.top}
|
||||||
|
x2={padding.left}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeWidth="1"
|
||||||
|
opacity="0.3"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* X-axis */}
|
||||||
|
<line
|
||||||
|
x1={padding.left}
|
||||||
|
y1={height - padding.bottom}
|
||||||
|
x2={viewBoxWidth - padding.right}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeWidth="1"
|
||||||
|
opacity="0.3"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Histogram bars */}
|
||||||
|
{bins.map((count, i) => {
|
||||||
|
const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0;
|
||||||
|
const x = padding.left + i * barWidth;
|
||||||
|
const y = height - padding.bottom - barHeight;
|
||||||
|
const binStart = min + i * data.binSize;
|
||||||
|
const binEnd = binStart + data.binSize;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<g key={i}>
|
||||||
|
<rect
|
||||||
|
x={x}
|
||||||
|
y={y}
|
||||||
|
width={Math.max(barWidth - 1, 1)}
|
||||||
|
height={barHeight}
|
||||||
|
fill="currentColor"
|
||||||
|
opacity="0.6"
|
||||||
|
className="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
|
||||||
|
/>
|
||||||
|
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
|
||||||
|
</g>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
|
||||||
|
{/* Percentile lines */}
|
||||||
|
<line
|
||||||
|
x1={getXPosition(p50)}
|
||||||
|
y1={padding.top}
|
||||||
|
x2={getXPosition(p50)}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeWidth="2"
|
||||||
|
strokeDasharray="4 2"
|
||||||
|
opacity="0.7"
|
||||||
|
className="text-gray-600 dark:text-gray-400"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<line
|
||||||
|
x1={getXPosition(p95)}
|
||||||
|
y1={padding.top}
|
||||||
|
x2={getXPosition(p95)}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeWidth="2"
|
||||||
|
strokeDasharray="4 2"
|
||||||
|
opacity="0.7"
|
||||||
|
className="text-orange-500 dark:text-orange-400"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<line
|
||||||
|
x1={getXPosition(p99)}
|
||||||
|
y1={padding.top}
|
||||||
|
x2={getXPosition(p99)}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeWidth="2"
|
||||||
|
strokeDasharray="4 2"
|
||||||
|
opacity="0.7"
|
||||||
|
className="text-green-500 dark:text-green-400"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* X-axis labels */}
|
||||||
|
<text
|
||||||
|
x={padding.left}
|
||||||
|
y={height - 5}
|
||||||
|
fontSize="10"
|
||||||
|
fill="currentColor"
|
||||||
|
opacity="0.6"
|
||||||
|
textAnchor="start"
|
||||||
|
>
|
||||||
|
{min.toFixed(1)}
|
||||||
|
</text>
|
||||||
|
|
||||||
|
<text
|
||||||
|
x={viewBoxWidth - padding.right}
|
||||||
|
y={height - 5}
|
||||||
|
fontSize="10"
|
||||||
|
fill="currentColor"
|
||||||
|
opacity="0.6"
|
||||||
|
textAnchor="end"
|
||||||
|
>
|
||||||
|
{max.toFixed(1)}
|
||||||
|
</text>
|
||||||
|
|
||||||
|
{/* X-axis label */}
|
||||||
|
<text
|
||||||
|
x={padding.left + chartWidth / 2}
|
||||||
|
y={height - 2}
|
||||||
|
fontSize="10"
|
||||||
|
fill="currentColor"
|
||||||
|
opacity="0.6"
|
||||||
|
textAnchor="middle"
|
||||||
|
>
|
||||||
|
Tokens/Second Distribution
|
||||||
|
</text>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function StatsPanel() {
|
function StatsPanel() {
|
||||||
const { metrics } = useAPI();
|
const { metrics } = useAPI();
|
||||||
|
|
||||||
const [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond] = useMemo(() => {
|
const [totalRequests, totalInputTokens, totalOutputTokens, tokenStats, histogramData] = useMemo(() => {
|
||||||
const totalRequests = metrics.length;
|
const totalRequests = metrics.length;
|
||||||
if (totalRequests === 0) {
|
if (totalRequests === 0) {
|
||||||
return [0, 0, 0];
|
return [0, 0, 0, { p99: 0, p95: 0, p50: 0 }, null];
|
||||||
}
|
}
|
||||||
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||||
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||||
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
|
|
||||||
return [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond];
|
// Calculate token statistics using output_tokens and duration_ms
|
||||||
|
// Filter out metrics with invalid duration or output tokens
|
||||||
|
const validMetrics = metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
|
||||||
|
if (validMetrics.length === 0) {
|
||||||
|
return [totalRequests, totalInputTokens, totalOutputTokens, { p99: 0, p95: 0, p50: 0 }, null];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate tokens/second for each valid metric
|
||||||
|
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
|
||||||
|
|
||||||
|
// Sort for percentile calculation
|
||||||
|
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
|
||||||
|
|
||||||
|
// Calculate percentiles - showing speed thresholds where X% of requests are SLOWER (below)
|
||||||
|
// P99: 99% of requests are slower than this speed (99th percentile - fast requests)
|
||||||
|
// P95: 95% of requests are slower than this speed (95th percentile)
|
||||||
|
// P50: 50% of requests are slower than this speed (median)
|
||||||
|
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
|
||||||
|
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
|
||||||
|
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
|
||||||
|
|
||||||
|
// Create histogram data
|
||||||
|
const min = Math.min(...tokensPerSecond);
|
||||||
|
const max = Math.max(...tokensPerSecond);
|
||||||
|
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5))); // Adaptive bin count
|
||||||
|
const binSize = (max - min) / binCount;
|
||||||
|
|
||||||
|
const bins = Array(binCount).fill(0);
|
||||||
|
tokensPerSecond.forEach((value) => {
|
||||||
|
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
|
||||||
|
bins[binIndex]++;
|
||||||
|
});
|
||||||
|
|
||||||
|
const histogramData = {
|
||||||
|
bins,
|
||||||
|
min,
|
||||||
|
max,
|
||||||
|
binSize,
|
||||||
|
p99,
|
||||||
|
p95,
|
||||||
|
p50,
|
||||||
|
};
|
||||||
|
|
||||||
|
return [
|
||||||
|
totalRequests,
|
||||||
|
totalInputTokens,
|
||||||
|
totalOutputTokens,
|
||||||
|
{
|
||||||
|
p99: p99.toFixed(2),
|
||||||
|
p95: p95.toFixed(2),
|
||||||
|
p50: p50.toFixed(2),
|
||||||
|
},
|
||||||
|
histogramData,
|
||||||
|
];
|
||||||
}, [metrics]);
|
}, [metrics]);
|
||||||
|
|
||||||
|
const nf = new Intl.NumberFormat();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="card">
|
<div className="card">
|
||||||
<div className="rounded-lg overflow-hidden border border-gray-200 dark:border-white/10">
|
<div className="rounded-lg overflow-hidden border border-card-border-inner">
|
||||||
<table className="w-full">
|
<table className="min-w-full divide-y divide-card-border-inner">
|
||||||
<thead>
|
<thead className="bg-secondary">
|
||||||
<tr className="border-b border-gray-200 dark:border-white/10 text-right">
|
<tr>
|
||||||
<th>Requests</th>
|
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">
|
||||||
<th className="border-l border-gray-200 dark:border-white/10">Processed</th>
|
Requests
|
||||||
<th className="border-l border-gray-200 dark:border-white/10">Generated</th>
|
</th>
|
||||||
<th className="border-l border-gray-200 dark:border-white/10">Tokens/Sec</th>
|
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||||
|
Processed
|
||||||
|
</th>
|
||||||
|
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||||
|
Generated
|
||||||
|
</th>
|
||||||
|
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||||
|
Token Stats (tokens/sec)
|
||||||
|
</th>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
|
||||||
<tr className="text-right">
|
<tbody className="bg-surface divide-y divide-card-border-inner">
|
||||||
<td className="border-r border-gray-200 dark:border-white/10">{totalRequests}</td>
|
<tr className="hover:bg-secondary">
|
||||||
<td className="border-r border-gray-200 dark:border-white/10">
|
<td className="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">{totalRequests}</td>
|
||||||
{new Intl.NumberFormat().format(totalInputTokens)}
|
|
||||||
|
<td className="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-sm font-medium">{nf.format(totalInputTokens)}</span>
|
||||||
|
<span className="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||||
|
</div>
|
||||||
</td>
|
</td>
|
||||||
<td className="border-r border-gray-200 dark:border-white/10">
|
|
||||||
{new Intl.NumberFormat().format(totalOutputTokens)}
|
<td className="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-sm font-medium">{nf.format(totalOutputTokens)}</span>
|
||||||
|
<span className="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||||
|
</div>
|
||||||
|
</td>
|
||||||
|
|
||||||
|
<td className="px-4 py-4 border-l border-gray-200 dark:border-white/10">
|
||||||
|
<div className="space-y-3">
|
||||||
|
<div className="grid grid-cols-3 gap-2 items-center">
|
||||||
|
<div className="text-center">
|
||||||
|
<div className="text-xs text-gray-500 dark:text-gray-400">P50</div>
|
||||||
|
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||||
|
{tokenStats.p50}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="text-center">
|
||||||
|
<div className="text-xs text-gray-500 dark:text-gray-400">P95</div>
|
||||||
|
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||||
|
{tokenStats.p95}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="text-center">
|
||||||
|
<div className="text-xs text-gray-500 dark:text-gray-400">P99</div>
|
||||||
|
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||||
|
{tokenStats.p99}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{histogramData && <TokenHistogram data={histogramData} />}
|
||||||
|
</div>
|
||||||
</td>
|
</td>
|
||||||
<td>{avgTokensPerSecond}</td>
|
|
||||||
</tr>
|
</tr>
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
|
|||||||
Reference in New Issue
Block a user