Add Peer Model Support (#438)
This PR allows a single llama-swap to be the central proxy for models served by other inference servers. The peer servers can be another llama-swap or any API that supports the /v1/* inference endpoint. Updates: #433, #299 Closes: #296
This commit is contained in:
+78
-10
@@ -2,6 +2,8 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -96,6 +98,12 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -108,17 +116,36 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize default metrics - these will always be recorded
|
||||
tm := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics skipped, empty body")
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsed
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
@@ -127,18 +154,18 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if tm, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
mp.addMetrics(tm)
|
||||
tm = parsedMetrics
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -251,6 +278,25 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
@@ -289,3 +335,25 @@ func (w *responseBodyCopier) Header() http.Header {
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding := strings.TrimSpace(strings.Split(part, ";")[0])
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user