diff --git a/config-schema.json b/config-schema.json index 63913981..069b0dc5 100644 --- a/config-schema.json +++ b/config-schema.json @@ -273,6 +273,58 @@ }, "additionalProperties": false, "description": "A dictionary of event triggers and actions. Only supported hook is on_startup." + }, + "logToStdout": { + "type": "string", + "enum": [ + "proxy", + "upstream", + "both", + "none" + ], + "default": "proxy", + "description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout." + }, + "apiKeys": { + "type": "array", + "items": { + "type": "string", + "minLength": 1 + }, + "default": [], + "description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string." + }, + "peers": { + "type": "object", + "additionalProperties": { + "type": "object", + "required": [ + "proxy", + "models" + ], + "properties": { + "proxy": { + "type": "string", + "format": "uri", + "description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value." + }, + "apiKey": { + "type": "string", + "default": "", + "description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>." + }, + "models": { + "type": "array", + "items": { + "type": "string", + "minLength": 1 + }, + "description": "A list of models served by the peer." + } + } + }, + "default": {}, + "description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap." 
} } } diff --git a/config.example.yaml b/config.example.yaml index 3ade9089..ab8030e4 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -341,3 +341,35 @@ hooks: # otherwise models will be loaded and swapped out preload: - "llama" + +# peers: a dictionary of remote peers and models they provide +# - optional, default empty dictionary +# - peers can be another llama-swap +# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap +peers: + # each key is the peer's ID + llama-swap-peer: + # proxy: a valid base URL to proxy requests to + # - required + # - requested path to llama-swap will be appended to the end of the proxy value + proxy: http://192.168.1.23 + # models: a list of models served by the peer + # - required + models: + - model_a + - model_b + - embeddings/model_c + openrouter: + proxy: https://openrouter.ai/api + # apiKey: a string key to be injected into the request + # - optional, default: "" + # - if blank, no key will be added to the request + # - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key> + apiKey: sk-your-openrouter-key + models: + - meta-llama/llama-3.1-8b-instruct + - qwen/qwen3-235b-a22b-2507 + - deepseek/deepseek-v3.2 + - z-ai/glm-4.7 + - moonshotai/kimi-k2-0905 + - minimax/minimax-m2.1 diff --git a/docs/configuration.md b/docs/configuration.md index 48d0c58d..5aac2706 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -86,7 +86,7 @@ llama-swap supports many more features to customize how you want to manage your ## Full Configuration Example > [!NOTE] -> This is a copy of `config.example.yaml`. Always check that for the most up to date examples. +> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations. 
```yaml # add this modeline for validation in vscode @@ -432,4 +432,36 @@ hooks: # otherwise models will be loaded and swapped out preload: - "llama" + +# peers: a dictionary of remote peers and models they provide +# - optional, default empty dictionary +# - peers can be another llama-swap +# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap +peers: + # each key is the peer's ID + llama-swap-peer: + # proxy: a valid base URL to proxy requests to + # - required + # - requested path to llama-swap will be appended to the end of the proxy value + proxy: http://192.168.1.23 + # models: a list of models served by the peer + # - required + models: + - model_a + - model_b + - embeddings/model_c + openrouter: + proxy: https://openrouter.ai/api + # apiKey: a string key to be injected into the request + # - optional, default: "" + # - if blank, no key will be added to the request + # - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key> + apiKey: sk-your-openrouter-key + models: + - meta-llama/llama-3.1-8b-instruct + - qwen/qwen3-235b-a22b-2507 + - deepseek/deepseek-v3.2 + - z-ai/glm-4.7 + - moonshotai/kimi-k2-0905 + - minimax/minimax-m2.1 ``` diff --git a/proxy/config/config.go b/proxy/config/config.go index 9a46e4d8..078d27fd 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -146,6 +146,9 @@ type Config struct { // support API keys, see issue #433, #50, #251 RequiredAPIKeys []string `yaml:"apiKeys"` + + // support remote peers, see issue #433, #296 + Peers PeerDictionaryConfig `yaml:"peers"` } func (c *Config) RealModelName(search string) (string, bool) { diff --git a/proxy/config/peer.go b/proxy/config/peer.go new file mode 100644 index 00000000..4d5ecfb9 --- /dev/null +++ b/proxy/config/peer.go @@ -0,0 +1,61 @@ +package config + +import ( + "fmt" + "net/url" +) + +// PeerDictionaryConfig maps a peer ID to that peer's PeerConfig. +type PeerDictionaryConfig map[string]PeerConfig + +// PeerConfig is the configuration of a single remote peer. +// ProxyURL is not read from YAML; UnmarshalYAML sets it by parsing Proxy. +type PeerConfig struct { + Proxy string `yaml:"proxy"` + ProxyURL *url.URL 
`yaml:"-"` + ApiKey string `yaml:"apiKey"` + Models []string `yaml:"models"` +} + +// UnmarshalYAML decodes a peer entry and validates it: proxy must be a +// non-empty absolute URL (scheme and host present) and models must contain at +// least one entry. On success the parsed URL is stored in ProxyURL. +func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + // rawPeerConfig has PeerConfig's fields but not its methods, so decoding + // into it does not call UnmarshalYAML again. + type rawPeerConfig PeerConfig + defaults := rawPeerConfig{ + Proxy: "", + ApiKey: "", + Models: []string{}, + } + + if err := unmarshal(&defaults); err != nil { + return err + } + + // Validate proxy is not empty + if defaults.Proxy == "" { + return fmt.Errorf("proxy is required") + } + + // Validate proxy is a valid URL and store the parsed value + parsedURL, err := url.Parse(defaults.Proxy) + if err != nil { + return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err) + } + // url.Parse accepts relative and opaque forms (e.g. "localhost:8080" parses + // with scheme "localhost" and no host); a proxy target needs scheme and host. + if parsedURL.Scheme == "" || parsedURL.Host == "" { + return fmt.Errorf("invalid peer proxy URL (%s): must be an absolute URL with scheme and host", defaults.Proxy) + } + defaults.ProxyURL = parsedURL + + // Validate models is not empty + if len(defaults.Models) == 0 { + return fmt.Errorf("peer models can not be empty") + } + + *c = PeerConfig(defaults) + return nil +} diff --git a/proxy/config/peer_test.go b/proxy/config/peer_test.go new file mode 100644 index 00000000..d02f619d --- /dev/null +++ b/proxy/config/peer_test.go @@ -0,0 +1,148 @@ +package config + +import ( + "testing" + + "gopkg.in/yaml.v3" +) + +func TestPeerConfig_UnmarshalYAML(t *testing.T) { + tests := []struct { + name string + yaml string + wantErr string + }{ + { + name: "valid config", + yaml: ` +proxy: http://192.168.1.23 +models: + - model_a + - model_b +`, + wantErr: "", + }, + { + name: "valid config with apiKey", + yaml: ` +proxy: https://openrouter.ai/api +apiKey: sk-test-key +models: + - meta-llama/llama-3.1-8b-instruct +`, + wantErr: "", + }, + { + name: "missing proxy", + yaml: ` +models: + - model_a +`, + wantErr: "proxy is required", + }, + { + name: "empty proxy", + yaml: ` +proxy: "" +models: + - model_a +`, + wantErr: "proxy is required", + }, + { + name: "invalid proxy URL", + yaml: ` +proxy: "://invalid" +models: + - model_a +`, + wantErr: "invalid peer proxy URL", + }, + { + name: "proxy without scheme or host", + yaml: ` +proxy: localhost:8080 +models: + - model_a +`, + wantErr: "invalid peer proxy URL", + }, + { + name: "missing models", + yaml: ` +proxy: http://localhost:8080 +`, + wantErr: "peer models can not be empty", + }, + { + 
name: "empty models", + yaml: ` +proxy: http://localhost:8080 +models: [] +`, + wantErr: "peer models can not be empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var config PeerConfig + err := yaml.Unmarshal([]byte(tt.yaml), &config) + + if tt.wantErr == "" { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } else { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.wantErr) + } else if !contains(err.Error(), tt.wantErr) { + t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error()) + } + } + }) + } +} + +func TestPeerConfig_ProxyURL(t *testing.T) { + yamlData := ` +proxy: http://192.168.1.23:8080/api +apiKey: sk-test +models: + - model_a +` + var config PeerConfig + err := yaml.Unmarshal([]byte(yamlData), &config) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if config.ProxyURL == nil { + t.Fatal("ProxyURL should not be nil") + } + + if config.ProxyURL.Host != "192.168.1.23:8080" { + t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host) + } + + if config.ProxyURL.Scheme != "http" { + t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme) + } + + if config.ProxyURL.Path != "/api" { + t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && searchSubstring(s, substr) +} + +func searchSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/proxy/metrics_monitor.go b/proxy/metrics_monitor.go index c948d3df..a3b07de2 100644 --- a/proxy/metrics_monitor.go +++ b/proxy/metrics_monitor.go @@ -2,6 +2,9 @@ package proxy import ( "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" "encoding/json" "fmt" "io" @@ -96,6 +99,12 @@ func (mp *metricsMonitor) wrapHandler( next func(modelID string, w http.ResponseWriter, r *http.Request) 
error, ) error { recorder := newBodyCopier(writer) + + // Filter Accept-Encoding to only include encodings we can decompress for metrics + if ae := request.Header.Get("Accept-Encoding"); ae != "" { + request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae)) + } + if err := next(modelID, recorder, request); err != nil { return err } @@ -108,17 +117,36 @@ func (mp *metricsMonitor) wrapHandler( return nil } + // Initialize default metrics - these will always be recorded + tm := TokenMetrics{ + Timestamp: time.Now(), + Model: modelID, + DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), + } + body := recorder.body.Bytes() if len(body) == 0 { - mp.logger.Warn("metrics skipped, empty body") + mp.logger.Warn("metrics: empty body, recording minimal metrics") + mp.addMetrics(tm) return nil } - if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") { - if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { - mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path) - } else { + // Decompress if needed + if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" { + var err error + body, err = decompressBody(body, encoding) + if err != nil { + mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path) mp.addMetrics(tm) + return nil + } + } + + if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") { + if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { + mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path) + } else { + tm = parsed } } else { if gjson.ValidBytes(body) { @@ -127,18 +155,18 @@ func (mp *metricsMonitor) wrapHandler( timings := parsed.Get("timings") if usage.Exists() || timings.Exists() { - if tm, err := parseMetrics(modelID, recorder.StartTime(), usage, 
timings); err != nil { - mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path) + if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil { + mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path) } else { - mp.addMetrics(tm) + tm = parsedMetrics } } - } else { - mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path) + mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path) } } + mp.addMetrics(tm) return nil } @@ -251,6 +279,37 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) }, nil } +// decompressBody returns body decoded according to the Content-Encoding value +// in encoding. "gzip" and "deflate" are supported; any other value (including +// an empty one) returns body unchanged with a nil error. +func decompressBody(body []byte, encoding string) ([]byte, error) { + switch strings.ToLower(strings.TrimSpace(encoding)) { + case "gzip": + reader, err := gzip.NewReader(bytes.NewReader(body)) + if err != nil { + return nil, err + } + defer reader.Close() + return io.ReadAll(reader) + case "deflate": + // RFC 9110 defines "deflate" as a zlib-wrapped stream (RFC 1950), but + // some servers send raw DEFLATE (RFC 1951). Try zlib first and fall + // back to raw DEFLATE if the zlib header or stream is invalid. + if zr, err := zlib.NewReader(bytes.NewReader(body)); err == nil { + decoded, readErr := io.ReadAll(zr) + zr.Close() + if readErr == nil { + return decoded, nil + } + } + reader := flate.NewReader(bytes.NewReader(body)) + defer reader.Close() + return io.ReadAll(reader) + default: + return body, nil // Return as-is for unknown/no encoding + } +} + // responseBodyCopier records the response body and writes to the original response writer // while also capturing it in a buffer for later processing type responseBodyCopier struct { @@ -289,3 +348,25 @@ func (w *responseBodyCopier) Header() http.Header { func (w *responseBodyCopier) StartTime() time.Time { return w.start } + +// filterAcceptEncoding filters the Accept-Encoding header to only include +// encodings we can decompress (gzip, deflate). This respects the client's +// preferences while ensuring we can parse response bodies for metrics. 
+func filterAcceptEncoding(acceptEncoding string) string { + if acceptEncoding == "" { + return "" + } + + supported := map[string]bool{"gzip": true, "deflate": true} + var filtered []string + + for _, part := range strings.Split(acceptEncoding, ",") { + // Parse encoding and optional quality value (e.g., "gzip;q=1.0") + encoding := strings.TrimSpace(strings.Split(part, ";")[0]) + if supported[strings.ToLower(encoding)] { + filtered = append(filtered, strings.TrimSpace(part)) + } + } + + return strings.Join(filtered, ", ") +} diff --git a/proxy/metrics_monitor_test.go b/proxy/metrics_monitor_test.go index fb353884..b68cf191 100644 --- a/proxy/metrics_monitor_test.go +++ b/proxy/metrics_monitor_test.go @@ -1,6 +1,9 @@ package proxy import ( + "bytes" + "compress/flate" + "compress/gzip" "encoding/json" "net/http" "net/http/httptest" @@ -291,7 +294,7 @@ data: [DONE] assert.Equal(t, 0, len(metrics)) }) - t.Run("empty response body does not record metrics", func(t *testing.T) { + t.Run("empty response body records minimal metrics", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10) nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { @@ -307,10 +310,13 @@ data: [DONE] assert.NoError(t, err) metrics := mm.getMetrics() - assert.Equal(t, 0, len(metrics)) + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 0, metrics[0].InputTokens) + assert.Equal(t, 0, metrics[0].OutputTokens) }) - t.Run("invalid JSON does not record metrics", func(t *testing.T) { + t.Run("invalid JSON records minimal metrics", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10) nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { @@ -328,7 +334,10 @@ data: [DONE] assert.NoError(t, err) // Errors after response is sent are logged, not returned metrics := mm.getMetrics() - assert.Equal(t, 0, len(metrics)) + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", 
metrics[0].Model) + assert.Equal(t, 0, metrics[0].InputTokens) + assert.Equal(t, 0, metrics[0].OutputTokens) }) t.Run("next handler error is propagated", func(t *testing.T) { @@ -350,7 +359,7 @@ data: [DONE] assert.Equal(t, 0, len(metrics)) }) - t.Run("response without usage or timings does not record metrics", func(t *testing.T) { + t.Run("response without usage or timings records minimal metrics", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10) responseBody := `{"result": "ok"}` @@ -367,10 +376,13 @@ data: [DONE] ginCtx, _ := gin.CreateTestContext(rec) err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) - assert.NoError(t, err) // Errors after response is sent are logged, not returned + assert.NoError(t, err) metrics := mm.getMetrics() - assert.Equal(t, 0, len(metrics)) + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 0, metrics[0].InputTokens) + assert.Equal(t, 0, metrics[0].OutputTokens) }) } @@ -598,7 +610,7 @@ data: [DONE] assert.Equal(t, 50, metrics[0].OutputTokens) }) - t.Run("handles streaming with no valid JSON", func(t *testing.T) { + t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10) responseBody := `data: not json @@ -619,13 +631,16 @@ data: [DONE] ginCtx, _ := gin.CreateTestContext(rec) err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) - assert.NoError(t, err) // Errors after response is sent are logged, not returned + assert.NoError(t, err) metrics := mm.getMetrics() - assert.Equal(t, 0, len(metrics)) + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 0, metrics[0].InputTokens) + assert.Equal(t, 0, metrics[0].OutputTokens) }) - t.Run("handles empty streaming response", func(t *testing.T) { + t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) { mm := newMetricsMonitor(testLogger, 10) 
responseBody := `` @@ -642,11 +657,13 @@ data: [DONE] ginCtx, _ := gin.CreateTestContext(rec) err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) - // Empty body should not trigger WrapHandler processing assert.NoError(t, err) metrics := mm.getMetrics() - assert.Equal(t, 0, len(metrics)) + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 0, metrics[0].InputTokens) + assert.Equal(t, 0, metrics[0].OutputTokens) }) } @@ -691,3 +708,127 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) { mm.addMetrics(metric) } } + +func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) { + t.Run("gzip encoded response", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}` + + // Compress with gzip + var buf bytes.Buffer + gzWriter := gzip.NewWriter(&buf) + gzWriter.Write([]byte(responseBody)) + gzWriter.Close() + compressedBody := buf.Bytes() + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "gzip") + w.WriteHeader(http.StatusOK) + w.Write(compressedBody) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 100, metrics[0].InputTokens) + assert.Equal(t, 50, metrics[0].OutputTokens) + }) + + t.Run("deflate encoded response", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}` + + // Compress with deflate + var buf bytes.Buffer + flateWriter, _ := flate.NewWriter(&buf, 
flate.DefaultCompression) + flateWriter.Write([]byte(responseBody)) + flateWriter.Close() + compressedBody := buf.Bytes() + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "deflate") + w.WriteHeader(http.StatusOK) + w.Write(compressedBody) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 200, metrics[0].InputTokens) + assert.Equal(t, 75, metrics[0].OutputTokens) + }) + + t.Run("invalid gzip data records minimal metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + // Invalid compressed data + invalidData := []byte("this is not gzip data") + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "gzip") + w.WriteHeader(http.StatusOK) + w.Write(invalidData) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) // Should not return error, just log warning + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 0, metrics[0].InputTokens) + assert.Equal(t, 0, metrics[0].OutputTokens) + }) + + t.Run("unknown encoding treated as uncompressed", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}` + + nextHandler := func(modelID 
string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "unknown-encoding") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, 300, metrics[0].InputTokens) + assert.Equal(t, 100, metrics[0].OutputTokens) + }) +} diff --git a/proxy/peerproxy.go b/proxy/peerproxy.go new file mode 100644 index 00000000..876f6bff --- /dev/null +++ b/proxy/peerproxy.go @@ -0,0 +1,143 @@ +package proxy + +import ( + "fmt" + "net" + "net/http" + "net/http/httputil" + "runtime" + "sort" + "strings" + "time" + + "github.com/mostlygeek/llama-swap/proxy/config" +) + +// peerProxyMember holds the reverse proxy and API key used to reach one peer. +type peerProxyMember struct { + peerID string + reverseProxy *httputil.ReverseProxy + apiKey string +} + +// PeerProxy forwards requests to remote peers, selecting the peer by model ID. +type PeerProxy struct { + peers config.PeerDictionaryConfig + proxyMap map[string]*peerProxyMember +} + +// NewPeerProxy builds one reverse proxy per configured peer and maps every model +// the peer lists to it. Peers are processed in sorted ID order; when two peers list +// the same model the first one keeps it and a warning is logged. An error is +// returned if a peer has no parsed ProxyURL. +func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) { + proxyMap := make(map[string]*peerProxyMember) + + // Sort peer IDs for consistent iteration order + peerIDs := make([]string, 0, len(peers)) + for peerID := range peers { + peerIDs = append(peerIDs, peerID) + } + sort.Strings(peerIDs) + + // Create a shared transport with reasonable timeouts for peer connections + // these can be tuned with feedback later + peerTransport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, // Connection timeout + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 60 * time.Second, // Time to wait for response headers + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConns: 100, + 
MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + } + + for _, peerID := range peerIDs { + peer := peers[peerID] + // UnmarshalYAML always sets ProxyURL, but a PeerConfig built in code may leave it + // nil, and the reverse proxy would dereference it. Fail here instead. + if peer.ProxyURL == nil { + return nil, fmt.Errorf("peer %s: proxy URL is not set", peerID) + } + // Create reverse proxy for this peer + reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL) + reverseProxy.Transport = peerTransport + + // Wrap Director to set Host header for remote hosts (not localhost) + originalDirector := reverseProxy.Director + reverseProxy.Director = func(req *http.Request) { + originalDirector(req) + // Ensure Host header matches target URL for remote proxying + req.Host = req.URL.Host + } + + reverseProxy.ModifyResponse = func(resp *http.Response) error { + if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") { + resp.Header.Set("X-Accel-Buffering", "no") + } + return nil + } + + reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err) + errMsg := fmt.Sprintf("peer proxy error: %v", err) + if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") { + errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)" + } + http.Error(w, errMsg, http.StatusBadGateway) + } + + pp := &peerProxyMember{ + peerID: peerID, + reverseProxy: reverseProxy, + apiKey: peer.ApiKey, + } + + // Map each model to this peer's proxy + for _, modelID := range peer.Models { + if _, found := proxyMap[modelID]; found { + proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID) + continue + } + proxyMap[modelID] = pp + } + } + + return &PeerProxy{ + peers: peers, + proxyMap: proxyMap, + }, nil +} + +// HasPeerModel reports whether modelID is mapped to a peer. +func (p *PeerProxy) HasPeerModel(modelID string) bool { + _, found := p.proxyMap[modelID] + return found +} + +// ListPeers returns the peer configuration passed to NewPeerProxy. +func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig { + return p.peers +} + +// ProxyRequest forwards request to the peer mapped to modelID and writes the peer's +// response to writer. If the peer has an API key it overwrites the request's +// Authorization and x-api-key headers. It returns an error only when modelID has no peer. +func (p *PeerProxy) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) 
error { + pp, found := p.proxyMap[modelID] + if !found { + return fmt.Errorf("no peer proxy found for model %s", modelID) + } + + // Inject API key if configured for this peer + if pp.apiKey != "" { + request.Header.Set("Authorization", "Bearer "+pp.apiKey) + request.Header.Set("x-api-key", pp.apiKey) + } + + pp.reverseProxy.ServeHTTP(writer, request) + return nil +} diff --git a/proxy/peerproxy_test.go b/proxy/peerproxy_test.go new file mode 100644 index 00000000..c6158dc6 --- /dev/null +++ b/proxy/peerproxy_test.go @@ -0,0 +1,280 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewPeerProxy_EmptyPeers(t *testing.T) { + peers := config.PeerDictionaryConfig{} + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + assert.NotNil(t, pm) + assert.Empty(t, pm.proxyMap) +} + +func TestNewPeerProxy_NilProxyURL(t *testing.T) { + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: "http://peer1.example.com:8080", + Models: []string{"model-a"}, + }, + } + pm, err := NewPeerProxy(peers, testLogger) + require.Error(t, err) + assert.Nil(t, pm) +} + +func TestNewPeerProxy_SinglePeer(t *testing.T) { + proxyURL, _ := url.Parse("http://peer1.example.com:8080") + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: "http://peer1.example.com:8080", + ProxyURL: proxyURL, + ApiKey: "test-key", + Models: []string{"model-a", "model-b"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + assert.Len(t, pm.proxyMap, 2) + assert.True(t, pm.HasPeerModel("model-a")) + assert.True(t, pm.HasPeerModel("model-b")) + assert.False(t, pm.HasPeerModel("model-c")) +} + +func TestNewPeerProxy_MultiplePeers(t *testing.T) { + proxyURL1, _ := url.Parse("http://peer1.example.com:8080") + proxyURL2, _ := url.Parse("http://peer2.example.com:8080") + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: "http://peer1.example.com:8080", + ProxyURL: proxyURL1, + Models: []string{"model-a", "model-b"}, + }, + "peer2": config.PeerConfig{ + Proxy: "http://peer2.example.com:8080", + 
ProxyURL: proxyURL2, + Models: []string{"model-c", "model-d"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + assert.Len(t, pm.proxyMap, 4) + assert.True(t, pm.HasPeerModel("model-a")) + assert.True(t, pm.HasPeerModel("model-b")) + assert.True(t, pm.HasPeerModel("model-c")) + assert.True(t, pm.HasPeerModel("model-d")) +} + +func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) { + // When the same model is in multiple peers, only the first (lexicographically by peer ID) + // should be mapped, and a warning should be logged + proxyURL1, _ := url.Parse("http://peer1.example.com:8080") + proxyURL2, _ := url.Parse("http://peer2.example.com:8080") + peers := config.PeerDictionaryConfig{ + "alpha-peer": config.PeerConfig{ + Proxy: "http://peer1.example.com:8080", + ProxyURL: proxyURL1, + Models: []string{"duplicate-model"}, + }, + "beta-peer": config.PeerConfig{ + Proxy: "http://peer2.example.com:8080", + ProxyURL: proxyURL2, + Models: []string{"duplicate-model"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + // Should only have one entry for the duplicate model + assert.Len(t, pm.proxyMap, 1) + assert.True(t, pm.HasPeerModel("duplicate-model")) +} + +func TestHasPeerModel(t *testing.T) { + proxyURL, _ := url.Parse("http://peer1.example.com:8080") + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: "http://peer1.example.com:8080", + ProxyURL: proxyURL, + Models: []string{"existing-model"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + + assert.True(t, pm.HasPeerModel("existing-model")) + assert.False(t, pm.HasPeerModel("non-existing-model")) +} + +func TestProxyRequest_ModelNotFound(t *testing.T) { + peers := config.PeerDictionaryConfig{} + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + err = 
pm.ProxyRequest("non-existing-model", w, req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model") +} + +func TestProxyRequest_Success(t *testing.T) { + // Create a test server to act as the peer + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("response from peer")) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + err = pm.ProxyRequest("test-model", w, req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "response from peer", w.Body.String()) +} + +func TestProxyRequest_ApiKeyInjection(t *testing.T) { + // Create a test server that checks for the Authorization header + var receivedAuthHeader string + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + ApiKey: "secret-api-key", + Models: []string{"test-model"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + err = pm.ProxyRequest("test-model", w, req) + assert.NoError(t, err) + assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader) +} + +func TestProxyRequest_NoApiKey(t *testing.T) { + // Create a test server 
that checks for the Authorization header + var receivedAuthHeader string + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + ApiKey: "", // No API key + Models: []string{"test-model"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + err = pm.ProxyRequest("test-model", w, req) + assert.NoError(t, err) + assert.Empty(t, receivedAuthHeader) +} + +func TestProxyRequest_HostHeaderSet(t *testing.T) { + // Create a test server that checks the Host header + var receivedHost string + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHost = r.Host + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + err = pm.ProxyRequest("test-model", w, req) + assert.NoError(t, err) + // The Host header should be set to the target URL's host + assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:")) +} + +func TestProxyRequest_SSEHeaderModification(t *testing.T) { + // Create a test server that returns SSE content type + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + 
w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + proxyURL, _ := url.Parse(testServer.URL) + peers := config.PeerDictionaryConfig{ + "peer1": config.PeerConfig{ + Proxy: testServer.URL, + ProxyURL: proxyURL, + Models: []string{"test-model"}, + }, + } + + pm, err := NewPeerProxy(peers, testLogger) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + err = pm.ProxyRequest("test-model", w, req) + assert.NoError(t, err) + // The X-Accel-Buffering header should be set to "no" for SSE + assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering")) +} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 99e814f3..5a0752ac 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -50,6 +50,9 @@ type ProxyManager struct { buildDate string commit string version string + + // peer proxy see: #296, #433 + peerProxy *PeerProxy } func New(proxyConfig config.Config) *ProxyManager { @@ -133,6 +136,12 @@ func New(proxyConfig config.Config) *ProxyManager { maxMetrics = proxyConfig.MetricsMaxInMemory } + peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger) + if err != nil { + proxyLogger.Errorf("Disabling Peering. 
Failed to create proxy peers: %v", err) + peerProxy = nil + } + pm := &ProxyManager{ config: proxyConfig, ginEngine: gin.New(), @@ -151,6 +160,8 @@ func New(proxyConfig config.Config) *ProxyManager { buildDate: "unknown", commit: "abcd1234", version: "0", + + peerProxy: peerProxy, } // create the process groups @@ -166,22 +177,29 @@ func New(proxyConfig config.Config) *ProxyManager { // do it in the background, don't block startup -- not sure if good idea yet go func() { discardWriter := &DiscardWriter{} - for _, realModelName := range proxyConfig.Hooks.OnStartup.Preload { - proxyLogger.Infof("Preloading model: %s", realModelName) - processGroup, _, err := pm.swapProcessGroup(realModelName) + for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload { + modelID, ok := proxyConfig.RealModelName(preloadModelName) + + if !ok { + proxyLogger.Warnf("Preload model %s not found in config", preloadModelName) + continue + } + + proxyLogger.Infof("Preloading model: %s", modelID) + processGroup, err := pm.swapProcessGroup(modelID) if err != nil { event.Emit(ModelPreloadedEvent{ - ModelName: realModelName, + ModelName: modelID, Success: false, }) - proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err) + proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err) continue } else { req, _ := http.NewRequest("GET", "/", nil) - processGroup.ProxyRequest(realModelName, discardWriter, req) + processGroup.ProxyRequest(modelID, discardWriter, req) event.Emit(ModelPreloadedEvent{ - ModelName: realModelName, + ModelName: modelID, Success: true, }) } @@ -399,16 +417,10 @@ func (pm *ProxyManager) Shutdown() { pm.shutdownCancel() } -func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) { - // de-alias the real model name and get a real one - realModelName, found := pm.config.RealModelName(requestedModel) - if !found { - return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel) - 
} - +func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) { processGroup := pm.findGroupByModelName(realModelName) if processGroup == nil { - return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel) + return nil, fmt.Errorf("could not find process group for model %s", realModelName) } if processGroup.exclusive { @@ -420,54 +432,71 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, } } - return processGroup, realModelName, nil + return processGroup, nil } func (pm *ProxyManager) listModelsHandler(c *gin.Context) { data := make([]gin.H, 0, len(pm.config.Models)) createdTime := time.Now().Unix() + newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H { + record := gin.H{ + "id": modelId, + "object": "model", + "created": createdTime, + "owned_by": "llama-swap", + } + + if name := strings.TrimSpace(modelConfig.Name); name != "" { + record["name"] = name + } + if desc := strings.TrimSpace(modelConfig.Description); desc != "" { + record["description"] = desc + } + + // Add metadata if present + if len(modelConfig.Metadata) > 0 { + record["meta"] = gin.H{ + "llamaswap": modelConfig.Metadata, + } + } + return record + } + for id, modelConfig := range pm.config.Models { if modelConfig.Unlisted { continue } - newRecord := func(modelId string) gin.H { - record := gin.H{ - "id": modelId, - "object": "model", - "created": createdTime, - "owned_by": "llama-swap", - } - - if name := strings.TrimSpace(modelConfig.Name); name != "" { - record["name"] = name - } - if desc := strings.TrimSpace(modelConfig.Description); desc != "" { - record["description"] = desc - } - - // Add metadata if present - if len(modelConfig.Metadata) > 0 { - record["meta"] = gin.H{ - "llamaswap": modelConfig.Metadata, - } - } - return record - } - - data = append(data, newRecord(id)) + data = append(data, newRecord(id, modelConfig)) // Include aliases if pm.config.IncludeAliasesInList 
{ for _, alias := range modelConfig.Aliases { if alias := strings.TrimSpace(alias); alias != "" { - data = append(data, newRecord(alias)) + data = append(data, newRecord(alias, modelConfig)) } } } } + if pm.peerProxy != nil { + for peerID, peer := range pm.peerProxy.ListPeers() { + // add peer models + for _, modelID := range peer.Models { + // Skip unlisted models if not showing them + record := newRecord(modelID, config.ModelConfig{ + Name: fmt.Sprintf("%s: %s", peerID, modelID), + Metadata: map[string]any{ + "peerID": peerID, + }, + }) + + data = append(data, record) + } + } + } + // Sort by the "id" key sort.Slice(data, func(i, j int) bool { si, _ := data[i]["id"].(string) @@ -506,8 +535,8 @@ func (pm *ProxyManager) findModelInPath(path string) (searchName string, realNam searchModelName = searchModelName + "/" + part } - if real, ok := pm.config.RealModelName(searchModelName); ok { - return searchModelName, real, "/" + strings.Join(parts[i+1:], "/"), true + if modelID, ok := pm.config.RealModelName(searchModelName); ok { + return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true } } @@ -517,23 +546,22 @@ func (pm *ProxyManager) findModelInPath(path string) (searchName string, realNam func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { upstreamPath := c.Param("upstreamPath") - searchModelName, modelName, remainingPath, modelFound := pm.findModelInPath(upstreamPath) + searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath) if !modelFound { pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") return } - // Check if this is exactly a model name with no additional path - // and doesn't end with a trailing slash + // Redirect /upstream/modelname to /upstream/modelname/ for URL consistency. + // This ensures relative URLs in upstream responses resolve correctly and + // provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the + // HTTP method (301 would downgrade to GET). 
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") { - // Build new URL with query parameters preserved newPath := "/upstream/" + searchModelName + "/" if c.Request.URL.RawQuery != "" { newPath += "?" + c.Request.URL.RawQuery } - - // Use 308 for non-GET/HEAD requests to preserve method if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { c.Redirect(http.StatusMovedPermanently, newPath) } else { @@ -542,7 +570,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { return } - processGroup, realModelName, err := pm.swapProcessGroup(modelName) + processGroup, err := pm.swapProcessGroup(modelID) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) return @@ -554,15 +582,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { // attempt to record metrics if it is a POST request if pm.metricsMonitor != nil && c.Request.Method == "POST" { - if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { + if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) - pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath) + pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath) return } } else { - if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { + if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath) + 
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath) return } } @@ -581,41 +609,54 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) { return } - realModelName, found := pm.config.RealModelName(requestedModel) - if !found { - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel)) - return - } + // Look for a matching local model first + var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error - processGroup, _, err := pm.swapProcessGroup(realModelName) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) - return - } - - // issue #69 allow custom model names to be sent to upstream - useModelName := pm.config.Models[realModelName].UseModelName - if useModelName != "" { - bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) + modelID, found := pm.config.RealModelName(requestedModel) + if found { + processGroup, err := pm.swapProcessGroup(modelID) if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) return } - } - // issue #174 strip parameters from the JSON body - stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams() - if err != nil { // just log it and continue - pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error()) - } else { - for _, param := range stripParams { - pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param) - bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) + // issue #69 allow custom model names to be sent to upstream + useModelName := pm.config.Models[modelID].UseModelName + 
if useModelName != "" { + bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) return } } + + // issue #174 strip parameters from the JSON body + stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams() + if err != nil { // just log it and continue + pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error()) + } else { + for _, param := range stripParams { + pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param) + bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) + return + } + } + } + + pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) + nextHandler = processGroup.ProxyRequest + } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { + pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) + modelID = requestedModel + nextHandler = pm.peerProxy.ProxyRequest + + } + + if nextHandler == nil { + pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel)) + return } c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -628,19 +669,19 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) { // issue #366 extract values that downstream handlers may need isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool() ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming) - ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName) + ctx = 
context.WithValue(ctx, proxyCtxKey("model"), modelID) c.Request = c.Request.WithContext(ctx) if pm.metricsMonitor != nil && c.Request.Method == "POST" { - if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { + if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName) + pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID) return } } else { - if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { + if err := nextHandler(modelID, c.Writer, c.Request); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) + pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID) return } } @@ -660,7 +701,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { return } - processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) + modelID, found := pm.config.RealModelName(requestedModel) + if !found { + pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel)) + return + } + + processGroup, err := pm.swapProcessGroup(modelID) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) return @@ -678,7 +725,7 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { // If this is the model field and we have a profile, use just the model name if key == "model" { // # issue #69 allow custom model names to be sent 
to upstream - useModelName := pm.config.Models[realModelName].UseModelName + useModelName := pm.config.Models[modelID].UseModelName if useModelName != "" { fieldValue = useModelName @@ -749,9 +796,9 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { modifiedReq.ContentLength = int64(requestBuffer.Len()) // Use the modified request for proxying - if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { + if err := processGroup.ProxyRequest(modelID, c.Writer, modifiedReq); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) + pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, modelID) return } } diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index 629617dd..fe4326d0 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -18,6 +18,7 @@ type Model struct { Description string `json:"description"` State string `json:"state"` Unlisted bool `json:"unlisted"` + PeerID string `json:"peerID"` } func addApiHandlers(pm *ProxyManager) { @@ -83,6 +84,18 @@ func (pm *ProxyManager) getModelStatus() []Model { }) } + // Iterate over the peer models + if pm.peerProxy != nil { + for peerID, peer := range pm.peerProxy.ListPeers() { + for _, modelID := range peer.Models { + models = append(models, Model{ + Id: modelID, + PeerID: peerID, + }) + } + } + } + return models } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index dbff98ac..2330b32b 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -223,17 +223,23 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { model2Config.Name = " " // empty whitespace only strings will get ignored model2Config.Description = " " - config := config.Config{ + cfg := config.Config{ 
HealthCheckTimeout: 15, Models: map[string]config.ModelConfig{ "model1": model1Config, "model2": model2Config, "model3": getTestSimpleResponderConfig("model3"), }, + Peers: map[string]config.PeerConfig{ + "peer1": { + Proxy: "http://peer1:8080", + Models: []string{"peer-model-a", "peer-model-b"}, + }, + }, LogLevel: "error", } - proxy := New(config) + proxy := New(cfg) // Create a test request req := httptest.NewRequest("GET", "/v1/models", nil) @@ -258,14 +264,16 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { t.Fatalf("Failed to parse JSON response: %v", err) } - // Check the number of models returned - assert.Len(t, response.Data, 3) + // Check the number of models returned (3 local + 2 peer models) + assert.Len(t, response.Data, 5) // Check the details of each model expectedModels := map[string]struct{}{ - "model1": {}, - "model2": {}, - "model3": {}, + "model1": {}, + "model2": {}, + "model3": {}, + "peer-model-a": {}, + "peer-model-b": {}, } // make all models @@ -296,6 +304,19 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { description, ok := model["description"].(string) assert.True(t, ok, "description should be a string") assert.Equal(t, "Model 1 description is used for testing", description) + } else if modelID == "peer-model-a" || modelID == "peer-model-b" { + // Peer models should have meta.llamaswap.peerID + meta, exists := model["meta"] + assert.True(t, exists, "peer model should have meta field") + metaMap, ok := meta.(map[string]interface{}) + assert.True(t, ok, "meta should be a map") + llamaswap, exists := metaMap["llamaswap"] + assert.True(t, exists, "meta should have llamaswap field") + llamaswapMap, ok := llamaswap.(map[string]interface{}) + assert.True(t, ok, "llamaswap should be a map") + peerID, exists := llamaswapMap["peerID"] + assert.True(t, exists, "llamaswap should have peerID field") + assert.Equal(t, "peer1", peerID) } else { _, exists := model["name"] assert.False(t, exists, "unexpected name field for model: 
%s", modelID) @@ -1287,3 +1308,215 @@ func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) }) } + +// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration +// in proxyInferenceHandler for issue #433 +func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) { + t.Run("requests to peer models are proxied", func(t *testing.T) { + // Create a test server to act as the peer + peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`)) + })) + defer peerServer.Close() + + // Create config with peers but no local model for "peer-model" + configStr := fmt.Sprintf(` +logLevel: error +peers: + test-peer: + proxy: %s + models: + - peer-model +models: + local-model: + cmd: %s -port ${PORT} -silent -respond local-model +`, peerServer.URL, getSimpleResponderPath()) + + testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) + assert.NoError(t, err) + + proxy := New(testConfig) + defer proxy.StopProcesses(StopImmediately) + + reqBody := `{"model":"peer-model"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "from-peer") + }) + + t.Run("local models take precedence over peer models", func(t *testing.T) { + // Create a test server to act as the peer - should NOT be called + peerCalled := false + peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + peerCalled = true + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"response":"from-peer"}`)) + })) + defer peerServer.Close() + + // Create config where "shared-model" exists 
both locally and on peer + configStr := fmt.Sprintf(` +logLevel: error +peers: + test-peer: + proxy: %s + models: + - shared-model +models: + shared-model: + cmd: %s -port ${PORT} -silent -respond local-response +`, peerServer.URL, getSimpleResponderPath()) + + testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) + assert.NoError(t, err) + + proxy := New(testConfig) + defer proxy.StopProcesses(StopImmediately) + + reqBody := `{"model":"shared-model"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "local-response") + assert.False(t, peerCalled, "peer should not be called when local model exists") + }) + + t.Run("unknown model returns error", func(t *testing.T) { + // Create a test server to act as the peer + peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer peerServer.Close() + + configStr := fmt.Sprintf(` +logLevel: error +peers: + test-peer: + proxy: %s + models: + - peer-model +models: + local-model: + cmd: %s -port ${PORT} -silent -respond local-model +`, peerServer.URL, getSimpleResponderPath()) + + testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) + assert.NoError(t, err) + + proxy := New(testConfig) + defer proxy.StopProcesses(StopImmediately) + + reqBody := `{"model":"unknown-model"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "could not find suitable inference handler") + }) + + t.Run("peer API key is injected into request", func(t *testing.T) { + var receivedAuthHeader string + peerServer := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"response":"ok"}`)) + })) + defer peerServer.Close() + + configStr := fmt.Sprintf(` +logLevel: error +peers: + test-peer: + proxy: %s + apiKey: secret-peer-key + models: + - peer-model +models: + local-model: + cmd: %s -port ${PORT} -silent -respond local-model +`, peerServer.URL, getSimpleResponderPath()) + + testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) + assert.NoError(t, err) + + proxy := New(testConfig) + defer proxy.StopProcesses(StopImmediately) + + reqBody := `{"model":"peer-model"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader) + }) + + t.Run("no peers configured - unknown model returns error", func(t *testing.T) { + testConfig := config.AddDefaultGroupToConfig(config.Config{ + HealthCheckTimeout: 15, + Models: map[string]config.ModelConfig{ + "local-model": getTestSimpleResponderConfig("local-model"), + }, + LogLevel: "error", + }) + + proxy := New(testConfig) + defer proxy.StopProcesses(StopImmediately) + + // peerProxy exists but has no peer models configured + assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model")) + + reqBody := `{"model":"unknown-model"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "could not find suitable inference handler") + }) + + t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) { + peerServer := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte("data: test\n\n")) + })) + defer peerServer.Close() + + configStr := fmt.Sprintf(` +logLevel: error +peers: + test-peer: + proxy: %s + models: + - peer-model +models: + local-model: + cmd: %s -port ${PORT} -silent -respond local-model +`, peerServer.URL, getSimpleResponderPath()) + + testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr)) + assert.NoError(t, err) + + proxy := New(testConfig) + defer proxy.StopProcesses(StopImmediately) + + reqBody := `{"model":"peer-model"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering")) + }) +} diff --git a/ui/src/contexts/APIProvider.tsx b/ui/src/contexts/APIProvider.tsx index 3740a1f6..294d4ed0 100644 --- a/ui/src/contexts/APIProvider.tsx +++ b/ui/src/contexts/APIProvider.tsx @@ -10,6 +10,7 @@ export interface Model { name: string; description: string; unlisted: boolean; + peerID: string; } interface APIProviderType { @@ -70,7 +71,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider const [versionInfo, setVersionInfo] = useState({ build_date: "unknown", commit: "unknown", - version: "unknown" + version: "unknown", }); //const apiEventSource = useRef(null); @@ -166,7 +167,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider }, []); useEffect(() => { - // fetch version + // fetch version const fetchVersion = async () => { try { const response = await fetch("/api/version"); @@ -180,7 +181,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider } }; - if (connectionStatus === 'connected') { + if (connectionStatus === 
"connected") { fetchVersion(); } }, [connectionStatus]); @@ -265,7 +266,19 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider connectionStatus, versionInfo, }), - [models, listModels, unloadAllModels, unloadSingleModel, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics, connectionStatus, versionInfo] + [ + models, + listModels, + unloadAllModels, + unloadSingleModel, + loadModel, + enableAPIEvents, + proxyLogs, + upstreamLogs, + metrics, + connectionStatus, + versionInfo, + ] ); return {children}; diff --git a/ui/src/pages/Models.tsx b/ui/src/pages/Models.tsx index c9793bdd..e7dc4357 100644 --- a/ui/src/pages/Models.tsx +++ b/ui/src/pages/Models.tsx @@ -44,8 +44,24 @@ function ModelsPanel() { const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name const [menuOpen, setMenuOpen] = useState(false); - const filteredModels = useMemo(() => { - return models.filter((model) => showUnlisted || !model.unlisted); + const { regularModels, peerModelsByPeerId } = useMemo(() => { + const filtered = models.filter((model) => showUnlisted || !model.unlisted); + const peerModels = filtered.filter((m) => m.peerID); + + // Group peer models by peerID + const grouped = peerModels.reduce((acc, model) => { + const peerId = model.peerID || "unknown"; + if (!acc[peerId]) { + acc[peerId] = []; + } + acc[peerId].push(model); + return acc; + }, {} as Record); + + return { + regularModels: filtered.filter((m) => !m.peerID), + peerModelsByPeerId: grouped, + }; }, [models, showUnlisted]); const handleUnloadAllModels = useCallback(async () => { @@ -151,7 +167,7 @@ function ModelsPanel() { - {filteredModels.map((model) => ( + {regularModels.map((model) => ( @@ -186,6 +202,34 @@ function ModelsPanel() { ))} + + {Object.keys(peerModelsByPeerId).length > 0 && ( + <> +

Peer Models

+ {Object.entries(peerModelsByPeerId) + .sort(([a], [b]) => a.localeCompare(b)) + .map(([peerId, models]) => ( +
+ + + + + + + + {models.map((model) => ( + + + + ))} + +
{peerId}
+ {model.id} +
+
+ ))} + + )} ); @@ -223,11 +267,7 @@ function TokenHistogram({ data }: { data: HistogramData }) { return (
- + {/* Y-axis */} {/* X-axis labels */} - + {min.toFixed(1)}