Compare commits

...

5 Commits

Author SHA1 Message Date
Benson Wong 4413881b2d proxy: actually add /v1/responses endpoint (#449)
ref: #448
2026-01-01 13:35:45 -08:00
Benson Wong 8df5e8563b proxy: add /v1/responses and /v1/audio/voices endpoints (#448)
Updates #433
Fixes #442 #226
2026-01-01 12:52:12 -08:00
Benson Wong 7931212d3e proxy: add v1/images/edits API endpoint (#447)
Updates #433
2026-01-01 12:43:06 -08:00
Benson Wong 3dc36032fb proxy: skip very slow tests in -short test mode (#446)
* proxy: skip very slow tests in -short test mode
* CLAUDE.md: update testing instructions
2025-12-31 14:08:56 -08:00
Benson Wong addb98646f proxy: add support for basic authorization (#445)
Fixes #444 where the UI with api keys did not work. The choice to use
http basic authorization is for simple, automatic browser support. No
changes to the UI were necessary. Just use an API key as the password,
no user name is required.
2025-12-31 13:42:35 -08:00
6 changed files with 88 additions and 28 deletions
+4 -2
View File
@@ -11,8 +11,10 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
## Testing ## Testing
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory - Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests. - Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests.
## Workflow Tasks ## Workflow Tasks
+3
View File
@@ -18,10 +18,13 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- ✅ OpenAI API supported endpoints: - ✅ OpenAI API supported endpoints:
- `v1/completions` - `v1/completions`
- `v1/chat/completions` - `v1/chat/completions`
- `v1/responses`
- `v1/embeddings` - `v1/embeddings`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867)) - `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- `v1/audio/voices`
- `v1/images/generations` - `v1/images/generations`
- `v1/images/edits`
- ✅ Anthropic API supported endpoints: - ✅ Anthropic API supported endpoints:
- `v1/messages` - `v1/messages`
- ✅ llama-server (llama.cpp) supported endpoints - ✅ llama-server (llama.cpp) supported endpoints
+4
View File
@@ -395,6 +395,10 @@ func TestProcess_StopImmediately(t *testing.T) {
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates // Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
// the upstream command // the upstream command
func TestProcess_ForceStopWithKill(t *testing.T) { func TestProcess_ForceStopWithKill(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("skipping SIGTERM test on Windows ") t.Skip("skipping SIGTERM test on Windows ")
} }
+4
View File
@@ -49,6 +49,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true // TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
// and multiple requests are made in parallel, only one process is running at a time. // and multiple requests are made in parallel, only one process is running at a time.
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{ Models: map[string]config.ModelConfig{
+22 -10
View File
@@ -3,6 +3,7 @@ package proxy
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
@@ -276,6 +277,7 @@ func (pm *ProxyManager) setupGinEngine() {
// Set up routes using the Gin engine // Set up routes using the Gin engine
// Protected routes use pm.apiKeyAuth() middleware // Protected routes use pm.apiKeyAuth() middleware
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support legacy /v1/completions api, see issue #12 // Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570) // Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
@@ -298,8 +300,10 @@ func (pm *ProxyManager) setupGinEngine() {
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler) pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
@@ -825,23 +829,30 @@ func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
xApiKey := c.GetHeader("x-api-key") xApiKey := c.GetHeader("x-api-key")
var bearerKey string var bearerKey string
var basicKey string
if auth := c.GetHeader("Authorization"); auth != "" { if auth := c.GetHeader("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") { if strings.HasPrefix(auth, "Bearer ") {
bearerKey = strings.TrimPrefix(auth, "Bearer ") bearerKey = strings.TrimPrefix(auth, "Bearer ")
} else if strings.HasPrefix(auth, "Basic ") {
// Basic Auth: base64(username:password), password is the API key
encoded := strings.TrimPrefix(auth, "Basic ")
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) == 2 {
basicKey = parts[1] // password is the API key
}
}
} }
} }
// If both headers present, they must match // Use first key found: Basic, then Bearer, then x-api-key
if xApiKey != "" && bearerKey != "" && xApiKey != bearerKey { var providedKey string
pm.sendErrorResponse(c, http.StatusBadRequest, "x-api-key and Authorization header values do not match") if basicKey != "" {
c.Abort() providedKey = basicKey
return } else if bearerKey != "" {
}
// Use x-api-key first, then Authorization
providedKey := xApiKey
if providedKey == "" {
providedKey = bearerKey providedKey = bearerKey
} else {
providedKey = xApiKey
} }
// Validate key // Validate key
@@ -854,6 +865,7 @@ func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
} }
if !valid { if !valid {
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key") pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
c.Abort() c.Abort()
return return
+51 -16
View File
@@ -3,6 +3,7 @@ package proxy
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand" "math/rand"
@@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel return r.closeChannel
} }
func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}
func CreateTestResponseRecorder() *TestResponseRecorder { func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{ return &TestResponseRecorder{
httptest.NewRecorder(), httptest.NewRecorder(),
@@ -523,6 +520,10 @@ func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
} }
func TestProxyManager_Shutdown(t *testing.T) { func TestProxyManager_Shutdown(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
// make broken model configurations // make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991) model1Config := getTestSimpleResponderConfigPort("model1", 9991)
model1Config.Proxy = "http://localhost:10001/" model1Config.Proxy = "http://localhost:10001/"
@@ -1253,18 +1254,6 @@ func TestProxyManager_APIKeyAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
}) })
t.Run("both headers with different keys returns 400", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
req.Header.Set("Authorization", "Bearer valid-key-2")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "do not match")
})
t.Run("invalid key returns 401", func(t *testing.T) { t.Run("invalid key returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}` reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
@@ -1284,6 +1273,52 @@ func TestProxyManager_APIKeyAuth(t *testing.T) {
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code) assert.Equal(t, http.StatusUnauthorized, w.Code)
}) })
t.Run("valid key in Basic Auth header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
// Basic Auth: base64("anyuser:valid-key-1")
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "unauthorized")
})
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
})
} }
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) { func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {