proxy: add API key support (#436)
Add configuration support for api keys that are enforced by llama-swap. Keys are stripped before sending them to upstream servers. Updates: #433, #50 and #251
This commit is contained in:
@@ -34,6 +34,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
|||||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||||
- `/log` - remote log monitoring
|
- `/log` - remote log monitoring
|
||||||
- `/health` - just returns "OK"
|
- `/health` - just returns "OK"
|
||||||
|
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||||
- ✅ Customizable
|
- ✅ Customizable
|
||||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||||
- Automatic unloading of models after timeout by setting a `ttl`
|
- Automatic unloading of models after timeout by setting a `ttl`
|
||||||
|
|||||||
@@ -70,6 +70,16 @@ sendLoadingState: true
|
|||||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||||
includeAliasesInList: false
|
includeAliasesInList: false
|
||||||
|
|
||||||
|
# apiKeys: require an API key when making requests to inference endpoints
|
||||||
|
# - optional, default: []
|
||||||
|
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||||
|
# - each key is a non-empty string
|
||||||
|
apiKeys:
|
||||||
|
- "sk-hunter2"
|
||||||
|
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||||
|
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||||
|
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
|
||||||
|
|
||||||
# macros: a dictionary of string substitutions
|
# macros: a dictionary of string substitutions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - macros are reusable snippets
|
# - macros are reusable snippets
|
||||||
|
|||||||
@@ -161,6 +161,16 @@ sendLoadingState: true
|
|||||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||||
includeAliasesInList: false
|
includeAliasesInList: false
|
||||||
|
|
||||||
|
# apiKeys: require an API key when making requests to inference endpoints
|
||||||
|
# - optional, default: []
|
||||||
|
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||||
|
# - each key is a non-empty string
|
||||||
|
apiKeys:
|
||||||
|
- "sk-hunter2"
|
||||||
|
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||||
|
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||||
|
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
|
||||||
|
|
||||||
# macros: a dictionary of string substitutions
|
# macros: a dictionary of string substitutions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - macros are reusable snippets
|
# - macros are reusable snippets
|
||||||
|
|||||||
@@ -143,6 +143,9 @@ type Config struct {
|
|||||||
|
|
||||||
// present aliases to /v1/models OpenAI API listing
|
// present aliases to /v1/models OpenAI API listing
|
||||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||||
|
|
||||||
|
// support API keys, see issue #433, #50, #251
|
||||||
|
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
@@ -418,6 +421,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
config.Hooks.OnStartup.Preload = toPreload
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check api keys validatity
|
||||||
|
for _, apikey := range config.RequiredAPIKeys {
|
||||||
|
if apikey == "" {
|
||||||
|
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(apikey, " ") {
|
||||||
|
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -761,3 +761,51 @@ models:
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
content: `apiKeys: [""]`,
|
||||||
|
expectedErr: "empty api key found in apiKeys",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blank spaces only",
|
||||||
|
content: `apiKeys: [" "]`,
|
||||||
|
expectedErr: "api key cannot contain spaces: ` `",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "contains leading space",
|
||||||
|
content: `apiKeys: [" key123"]`,
|
||||||
|
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "contains trailing space",
|
||||||
|
content: `apiKeys: ["key123 "]`,
|
||||||
|
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "contains middle space",
|
||||||
|
content: `apiKeys: ["key 123"]`,
|
||||||
|
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty in list with valid keys",
|
||||||
|
content: `apiKeys: ["valid-key", "", "another-key"]`,
|
||||||
|
expectedErr: "empty api key found in apiKeys",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Equal(t, tt.expectedErr, err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+73
-19
@@ -256,37 +256,38 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyInferenceHandler)
|
// Protected routes use pm.apiKeyAuth() middleware
|
||||||
|
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
pm.ginEngine.POST("/v1/completions", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||||
pm.ginEngine.POST("/v1/messages", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// Support embeddings and reranking
|
// Support embeddings and reranking
|
||||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// llama-server's /reranking endpoint + aliases
|
// llama-server's /reranking endpoint + aliases
|
||||||
pm.ginEngine.POST("/reranking", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/rerank", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/reranking", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// llama-server's /infill endpoint for code infilling
|
// llama-server's /infill endpoint for code infilling
|
||||||
pm.ginEngine.POST("/infill", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// llama-server's /completion endpoint
|
// llama-server's /completion endpoint
|
||||||
pm.ginEngine.POST("/completion", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||||
|
|
||||||
// in proxymanager_loghandlers.go
|
// in proxymanager_loghandlers.go
|
||||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
|
||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User Interface Endpoints
|
* User Interface Endpoints
|
||||||
@@ -298,9 +299,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||||
c.Redirect(http.StatusFound, "/ui/models")
|
c.Redirect(http.StatusFound, "/ui/models")
|
||||||
})
|
})
|
||||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.proxyToUpstream)
|
||||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
|
||||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
|
||||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
})
|
})
|
||||||
@@ -765,6 +766,59 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// apiKeyAuth returns a middleware that validates API keys if configured.
|
||||||
|
// Returns a pass-through handler if no API keys are configured.
|
||||||
|
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
||||||
|
if len(pm.config.RequiredAPIKeys) == 0 {
|
||||||
|
return func(c *gin.Context) { c.Next() }
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
xApiKey := c.GetHeader("x-api-key")
|
||||||
|
|
||||||
|
var bearerKey string
|
||||||
|
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||||
|
if strings.HasPrefix(auth, "Bearer ") {
|
||||||
|
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If both headers present, they must match
|
||||||
|
if xApiKey != "" && bearerKey != "" && xApiKey != bearerKey {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "x-api-key and Authorization header values do not match")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use x-api-key first, then Authorization
|
||||||
|
providedKey := xApiKey
|
||||||
|
if providedKey == "" {
|
||||||
|
providedKey = bearerKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate key
|
||||||
|
valid := false
|
||||||
|
for _, key := range pm.config.RequiredAPIKeys {
|
||||||
|
if providedKey == key {
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip auth headers to prevent leakage to upstream
|
||||||
|
c.Request.Header.Del("Authorization")
|
||||||
|
c.Request.Header.Del("x-api-key")
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
pm.StopProcesses(StopImmediately)
|
pm.StopProcesses(StopImmediately)
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ type Model struct {
|
|||||||
|
|
||||||
func addApiHandlers(pm *ProxyManager) {
|
func addApiHandlers(pm *ProxyManager) {
|
||||||
// Add API endpoints for React to consume
|
// Add API endpoints for React to consume
|
||||||
apiGroup := pm.ginEngine.Group("/api")
|
// Protected with API key authentication
|
||||||
|
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||||
{
|
{
|
||||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||||
|
|||||||
@@ -1187,3 +1187,103 @@ func TestProxyManager_ApiGetVersion(t *testing.T) {
|
|||||||
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
|
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_APIKeyAuth(t *testing.T) {
|
||||||
|
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(testConfig)
|
||||||
|
defer proxy.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
t.Run("valid key in x-api-key header", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
req.Header.Set("x-api-key", "valid-key-1")
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid key in Authorization Bearer header", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
req.Header.Set("Authorization", "Bearer valid-key-2")
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("both headers with matching keys", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
req.Header.Set("x-api-key", "valid-key-1")
|
||||||
|
req.Header.Set("Authorization", "Bearer valid-key-1")
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("both headers with different keys returns 400", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
req.Header.Set("x-api-key", "valid-key-1")
|
||||||
|
req.Header.Set("Authorization", "Bearer valid-key-2")
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "do not match")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid key returns 401", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
req.Header.Set("x-api-key", "invalid-key")
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing key returns 401", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
|
||||||
|
// Config without RequiredAPIKeys - auth should be disabled
|
||||||
|
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(testConfig)
|
||||||
|
defer proxy.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
t.Run("requests pass without API key when not configured", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user