From 53b32f3601a9ef74ca1c59f46d7a1f1a0273f214 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Tue, 23 Dec 2025 23:39:33 -0800 Subject: [PATCH] proxy: add API key support (#436) Add configuration support for api keys that are enforced by llama-swap. Keys are stripped before sending them to upstream servers. Updates: #433, #50 and #251 --- README.md | 1 + config.example.yaml | 10 ++++ docs/configuration.md | 10 ++++ proxy/config/config.go | 14 +++++ proxy/config/config_test.go | 48 +++++++++++++++++ proxy/proxymanager.go | 92 ++++++++++++++++++++++++++------- proxy/proxymanager_api.go | 3 +- proxy/proxymanager_test.go | 100 ++++++++++++++++++++++++++++++++++++ 8 files changed, 258 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 09b80b9e..daf934bb 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) - `/log` - remote log monitoring - `/health` - just returns "OK" +- ✅ API Key support - define keys to restrict access to API endpoints - ✅ Customizable - Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - Automatic unloading of models after timeout by setting a `ttl` diff --git a/config.example.yaml b/config.example.yaml index 923fb825..3ade9089 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -70,6 +70,16 @@ sendLoadingState: true # all fields except for Id so chat UIs can use the alias equivalent to the original. 
includeAliasesInList: false +# apiKeys: require an API key when making requests to inference endpoints +# - optional, default: [] +# - when empty (the default) authorization will not be checked as llama-swap is default-allow +# - each key is a non-empty string and must not contain spaces +apiKeys: + - "sk-hunter2" + # hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )" + - "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx" + - "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb" + # macros: a dictionary of string substitutions # - optional, default: empty dictionary # - macros are reusable snippets diff --git a/docs/configuration.md b/docs/configuration.md index 852a4a02..48d0c58d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -161,6 +161,16 @@ sendLoadingState: true # all fields except for Id so chat UIs can use the alias equivalent to the original. includeAliasesInList: false +# apiKeys: require an API key when making requests to inference endpoints +# - optional, default: [] +# - when empty (the default) authorization will not be checked as llama-swap is default-allow +# - each key is a non-empty string and must not contain spaces +apiKeys: + - "sk-hunter2" + # hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )" + - "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx" + - "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb" + # macros: a dictionary of string substitutions # - optional, default: empty dictionary # - macros are reusable snippets diff --git a/proxy/config/config.go b/proxy/config/config.go index c812204d..9a46e4d8 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -143,6 +143,9 @@ type Config struct { // present aliases to /v1/models OpenAI API listing IncludeAliasesInList bool `yaml:"includeAliasesInList"` + + // support API keys, see issue #433, #50, #251 + RequiredAPIKeys []string `yaml:"apiKeys"` } func (c *Config) RealModelName(search string) (string, 
bool) { @@ -418,6 +421,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { config.Hooks.OnStartup.Preload = toPreload } + // check api keys validity + for _, apikey := range config.RequiredAPIKeys { + if apikey == "" { + return Config{}, fmt.Errorf("empty api key found in apiKeys") + } + + if strings.Contains(apikey, " ") { + return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey) + } + } + return config, nil } diff --git a/proxy/config/config_test.go b/proxy/config/config_test.go index e624a8ce..ab358e66 100644 --- a/proxy/config/config_test.go +++ b/proxy/config/config_test.go @@ -761,3 +761,51 @@ models: }) } } + +func TestConfig_APIKeys_Invalid(t *testing.T) { + tests := []struct { + name string + content string + expectedErr string + }{ + { + name: "empty string", + content: `apiKeys: [""]`, + expectedErr: "empty api key found in apiKeys", + }, + { + name: "blank spaces only", + content: `apiKeys: [" "]`, + expectedErr: "api key cannot contain spaces: ` `", + }, + { + name: "contains leading space", + content: `apiKeys: [" key123"]`, + expectedErr: "api key cannot contain spaces: ` key123`", + }, + { + name: "contains trailing space", + content: `apiKeys: ["key123 "]`, + expectedErr: "api key cannot contain spaces: `key123 `", + }, + { + name: "contains middle space", + content: `apiKeys: ["key 123"]`, + expectedErr: "api key cannot contain spaces: `key 123`", + }, + { + name: "empty in list with valid keys", + content: `apiKeys: ["valid-key", "", "another-key"]`, + expectedErr: "empty api key found in apiKeys", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := LoadConfigFromReader(strings.NewReader(tt.content)) + if assert.Error(t, err) { + assert.Equal(t, tt.expectedErr, err.Error()) + } + }) + } +} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 2636f470..99e814f3 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -256,37 +256,38 @@ func (pm 
*ProxyManager) setupGinEngine() { }) // Set up routes using the Gin engine - pm.ginEngine.POST("/v1/chat/completions", pm.proxyInferenceHandler) + // Protected routes use pm.apiKeyAuth() middleware + pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler) // Support legacy /v1/completions api, see issue #12 - pm.ginEngine.POST("/v1/completions", pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler) // Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570) - pm.ginEngine.POST("/v1/messages", pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.proxyInferenceHandler) // Support embeddings and reranking - pm.ginEngine.POST("/v1/embeddings", pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.proxyInferenceHandler) // llama-server's /reranking endpoint + aliases - pm.ginEngine.POST("/reranking", pm.proxyInferenceHandler) - pm.ginEngine.POST("/rerank", pm.proxyInferenceHandler) - pm.ginEngine.POST("/v1/rerank", pm.proxyInferenceHandler) - pm.ginEngine.POST("/v1/reranking", pm.proxyInferenceHandler) + pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler) // llama-server's /infill endpoint for code infilling - pm.ginEngine.POST("/infill", pm.proxyInferenceHandler) + pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.proxyInferenceHandler) // llama-server's /completion endpoint - pm.ginEngine.POST("/completion", pm.proxyInferenceHandler) + pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.proxyInferenceHandler) // Support audio/speech endpoint - pm.ginEngine.POST("/v1/audio/speech", pm.proxyInferenceHandler) - 
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler) + pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler) + pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler) - pm.ginEngine.GET("/v1/models", pm.listModelsHandler) + pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler) // in proxymanager_loghandlers.go - pm.ginEngine.GET("/logs", pm.sendLogsHandlers) - pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) - pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.streamLogsHandler) + pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers) + pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler) + pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler) /** * User Interface Endpoints @@ -298,9 +299,9 @@ func (pm *ProxyManager) setupGinEngine() { pm.ginEngine.GET("/upstream", func(c *gin.Context) { c.Redirect(http.StatusFound, "/ui/models") }) - pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream) - pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler) - pm.ginEngine.GET("/running", pm.listRunningProcessesHandler) + pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.proxyToUpstream) + pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler) + pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler) pm.ginEngine.GET("/health", func(c *gin.Context) { c.String(http.StatusOK, "OK") }) @@ -765,6 +766,59 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag } } +// apiKeyAuth returns a middleware that validates API keys if configured. +// Returns a pass-through handler if no API keys are configured. 
+func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc { + if len(pm.config.RequiredAPIKeys) == 0 { + return func(c *gin.Context) { c.Next() } + } + + return func(c *gin.Context) { + xApiKey := c.GetHeader("x-api-key") + + var bearerKey string + if auth := c.GetHeader("Authorization"); auth != "" { + if strings.HasPrefix(auth, "Bearer ") { + bearerKey = strings.TrimPrefix(auth, "Bearer ") + } + } + + // If both headers present, they must match + if xApiKey != "" && bearerKey != "" && xApiKey != bearerKey { + pm.sendErrorResponse(c, http.StatusBadRequest, "x-api-key and Authorization header values do not match") + c.Abort() + return + } + + // Use x-api-key first, then Authorization + providedKey := xApiKey + if providedKey == "" { + providedKey = bearerKey + } + + // Validate key + valid := false + for _, key := range pm.config.RequiredAPIKeys { + if providedKey == key { + valid = true + break + } + } + + if !valid { + pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key") + c.Abort() + return + } + + // Strip auth headers to prevent leakage to upstream + c.Request.Header.Del("Authorization") + c.Request.Header.Del("x-api-key") + + c.Next() + } +} + func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { pm.StopProcesses(StopImmediately) c.String(http.StatusOK, "OK") diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index a296ee8c..629617dd 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -22,7 +22,8 @@ type Model struct { func addApiHandlers(pm *ProxyManager) { // Add API endpoints for React to consume - apiGroup := pm.ginEngine.Group("/api") + // Protected with API key authentication + apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth()) { apiGroup.POST("/models/unload", pm.apiUnloadAllModels) apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler) diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 4ae024e6..dbff98ac 100644 --- 
a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -1187,3 +1187,103 @@ func TestProxyManager_ApiGetVersion(t *testing.T) { assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key]) } } + +func TestProxyManager_APIKeyAuth(t *testing.T) { + testConfig := config.AddDefaultGroupToConfig(config.Config{ + HealthCheckTimeout: 15, + Models: map[string]config.ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + }, + RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"}, + LogLevel: "error", + }) + + proxy := New(testConfig) + defer proxy.StopProcesses(StopImmediately) + + t.Run("valid key in x-api-key header", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + req.Header.Set("x-api-key", "valid-key-1") + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("valid key in Authorization Bearer header", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + req.Header.Set("Authorization", "Bearer valid-key-2") + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("both headers with matching keys", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + req.Header.Set("x-api-key", "valid-key-1") + req.Header.Set("Authorization", "Bearer valid-key-1") + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("both headers with different keys returns 400", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + req.Header.Set("x-api-key", 
"valid-key-1") + req.Header.Set("Authorization", "Bearer valid-key-2") + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "do not match") + }) + + t.Run("invalid key returns 401", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + req.Header.Set("x-api-key", "invalid-key") + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Contains(t, w.Body.String(), "unauthorized") + }) + + t.Run("missing key returns 401", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) +} + +func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) { + // Config without RequiredAPIKeys - auth should be disabled + testConfig := config.AddDefaultGroupToConfig(config.Config{ + HealthCheckTimeout: 15, + Models: map[string]config.ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + }, + LogLevel: "error", + }) + + proxy := New(testConfig) + defer proxy.StopProcesses(StopImmediately) + + t.Run("requests pass without API key when not configured", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) +}