Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f |
@@ -11,8 +11,10 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
|||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||||
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||||
|
- Use `make test-dev` after running new tests for a quick overall test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory.
|
||||||
|
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||||
|
|
||||||
## Workflow Tasks
|
## Workflow Tasks
|
||||||
|
|
||||||
|
|||||||
@@ -18,10 +18,13 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
|||||||
- ✅ OpenAI API supported endpoints:
|
- ✅ OpenAI API supported endpoints:
|
||||||
- `v1/completions`
|
- `v1/completions`
|
||||||
- `v1/chat/completions`
|
- `v1/chat/completions`
|
||||||
|
- `v1/responses`
|
||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||||
|
- `v1/audio/voices`
|
||||||
- `v1/images/generations`
|
- `v1/images/generations`
|
||||||
|
- `v1/images/edits`
|
||||||
- ✅ Anthropic API supported endpoints:
|
- ✅ Anthropic API supported endpoints:
|
||||||
- `v1/messages`
|
- `v1/messages`
|
||||||
- ✅ llama-server (llama.cpp) supported endpoints
|
- ✅ llama-server (llama.cpp) supported endpoints
|
||||||
|
|||||||
@@ -395,6 +395,10 @@ func TestProcess_StopImmediately(t *testing.T) {
|
|||||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||||
// the upstream command
|
// the upstream command
|
||||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping slow test")
|
||||||
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("skipping SIGTERM test on Windows ")
|
t.Skip("skipping SIGTERM test on Windows ")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||||
// and multiple requests are made in parallel, only one process is running at a time.
|
// and multiple requests are made in parallel, only one process is running at a time.
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping slow test")
|
||||||
|
}
|
||||||
|
|
||||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]config.ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
|
|||||||
+22
-10
@@ -3,6 +3,7 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
@@ -276,6 +277,7 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
// Protected routes use pm.apiKeyAuth() middleware
|
// Protected routes use pm.apiKeyAuth() middleware
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||||
@@ -298,8 +300,10 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
|
||||||
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||||
|
|
||||||
@@ -825,23 +829,30 @@ func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
|||||||
xApiKey := c.GetHeader("x-api-key")
|
xApiKey := c.GetHeader("x-api-key")
|
||||||
|
|
||||||
var bearerKey string
|
var bearerKey string
|
||||||
|
var basicKey string
|
||||||
if auth := c.GetHeader("Authorization"); auth != "" {
|
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||||
if strings.HasPrefix(auth, "Bearer ") {
|
if strings.HasPrefix(auth, "Bearer ") {
|
||||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||||
|
} else if strings.HasPrefix(auth, "Basic ") {
|
||||||
|
// Basic Auth: base64(username:password), password is the API key
|
||||||
|
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||||
|
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||||
|
parts := strings.SplitN(string(decoded), ":", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
basicKey = parts[1] // password is the API key
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If both headers present, they must match
|
// Use first key found: Basic, then Bearer, then x-api-key
|
||||||
if xApiKey != "" && bearerKey != "" && xApiKey != bearerKey {
|
var providedKey string
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "x-api-key and Authorization header values do not match")
|
if basicKey != "" {
|
||||||
c.Abort()
|
providedKey = basicKey
|
||||||
return
|
} else if bearerKey != "" {
|
||||||
}
|
|
||||||
|
|
||||||
// Use x-api-key first, then Authorization
|
|
||||||
providedKey := xApiKey
|
|
||||||
if providedKey == "" {
|
|
||||||
providedKey = bearerKey
|
providedKey = bearerKey
|
||||||
|
} else {
|
||||||
|
providedKey = xApiKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate key
|
// Validate key
|
||||||
@@ -854,6 +865,7 @@ func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !valid {
|
if !valid {
|
||||||
|
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||||
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
|
|||||||
+51
-16
@@ -3,6 +3,7 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
|||||||
return r.closeChannel
|
return r.closeChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *TestResponseRecorder) closeClient() {
|
|
||||||
r.closeChannel <- true
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateTestResponseRecorder() *TestResponseRecorder {
|
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||||
return &TestResponseRecorder{
|
return &TestResponseRecorder{
|
||||||
httptest.NewRecorder(),
|
httptest.NewRecorder(),
|
||||||
@@ -523,6 +520,10 @@ func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_Shutdown(t *testing.T) {
|
func TestProxyManager_Shutdown(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping slow test")
|
||||||
|
}
|
||||||
|
|
||||||
// make broken model configurations
|
// make broken model configurations
|
||||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||||
model1Config.Proxy = "http://localhost:10001/"
|
model1Config.Proxy = "http://localhost:10001/"
|
||||||
@@ -1253,18 +1254,6 @@ func TestProxyManager_APIKeyAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("both headers with different keys returns 400", func(t *testing.T) {
|
|
||||||
reqBody := `{"model":"model1"}`
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
req.Header.Set("x-api-key", "valid-key-1")
|
|
||||||
req.Header.Set("Authorization", "Bearer valid-key-2")
|
|
||||||
w := CreateTestResponseRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
|
||||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
||||||
assert.Contains(t, w.Body.String(), "do not match")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("invalid key returns 401", func(t *testing.T) {
|
t.Run("invalid key returns 401", func(t *testing.T) {
|
||||||
reqBody := `{"model":"model1"}`
|
reqBody := `{"model":"model1"}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
@@ -1284,6 +1273,52 @@ func TestProxyManager_APIKeyAuth(t *testing.T) {
|
|||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("valid key in Basic Auth header", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
// Basic Auth: base64("anyuser:valid-key-1")
|
||||||
|
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
|
||||||
|
req.Header.Set("Authorization", "Basic "+credentials)
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
|
||||||
|
req.Header.Set("Authorization", "Basic "+credentials)
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
req.Header.Set("x-api-key", "valid-key-1")
|
||||||
|
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
|
||||||
|
req.Header.Set("Authorization", "Basic "+credentials)
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
|
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user