02e015fa49
This is a huge backend change that essentially started with rewriting the concurrency handling for processes and blew up into a refactor of the entire application. In short, these are the improvements: **Better state and life cycle management:** Life cycle management of processes has always been the trickiest part of the code. Juggling mutex locks between multiple locations to reduce race conditions was complex. Too complex for my feeble brain to build a simple mental model around as llama-swap gained more features. All of that has been refactored. Most of the locks are gone, replaced with a single run() that owns all state changes. There is now one place to start from to understand and extend routing logic. The improved life cycle management makes it easier to implement more complex swap optimization strategies in the future like #727. **Collation of requests:** llama-swap previously handled requests and swapping in the order they came in. For example, requests for models in this order ABCABC would result in 5 swaps. Now those requests are handled in this order AABBCC. The result is less time waiting for swaps under a high churn request queue. This fixes #588 #612. A possible future enhancement is to support a starvation parameter so a swap can be forced when models have been waiting too long. **Shared base implementation for groups and swap matrix:** During the refactor it became clear that much of the swapping logic was shared between these two implementations. That is not surprising considering the swap matrix was added many moons after groups. Now they share a common base and their specific swap strategies are implemented through the swapPlanner interface. Requests for bespoke or specific swapping scenarios are a common theme in the issues. Now users can implement whatever bespoke and weird swapping strategy they want in their own fork. Just ask your agent of choice to implement swapPlanner. 
I'll still remain more conservative on what actually lands in core llama-swap and will continue to evaluate PRs on whether the changes are good for everyone or just one specific use case. **AI / Agentic Disclosure:** I paid very close attention to the low level swap concurrency design and implementation. It's important to keep that essential part reliable, boring and free of surprises. Backwards compatibility was also maintained, even the one way non-exclusive group model loading behaviour that people have rightly pointed out to be a weird design decision. With the underlying swap core done, the web server, API and UI sitting on top were largely ported over with Claude Code and Opus 4.7 in multiple phases. If you're curious I kept the changes in docs/newrouter-todo.md. I did several passes to make sure things weren't left behind. However, even frontier LLMs at the time of this PR still make small decisions that don't make a lot of sense. They get shit wrong all the time, just in small, subtle ways. That said, there are likely to be some new bugs introduced with this massive refactor. I'm fairly confident that there are no major architectural flaws that would cause goal seeking agents to make dumb, ugly code decisions. For a little while the legacy llama-swap will be available under cmd/legacy/llama-swap. The plan is to eventually delete that entry point as well as the proxy package. On a bit of a personal note, this PR is exciting and a bit sad for me. I hand wrote much of the original code and this PR ultimately replaces much of it. While the old code served as a good reference for the agent to implement the new stuff, it is still a bit sad to eventually delete it all.
1882 lines
56 KiB
Go
1882 lines
56 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/rand"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/mostlygeek/llama-swap/internal/config"
|
|
"github.com/mostlygeek/llama-swap/internal/event"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// TestResponseRecorder wraps httptest.ResponseRecorder and supplies the
// CloseNotify method that the plain recorder lacks.
//
// "If you want to write your own tests around streams you will need a Recorder
// that can handle CloseNotifier." Without it the tests can panic with:
//
//	panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify
//
// See: https://github.com/gin-gonic/gin/issues/1815
//
// The type is taken from gin's own tests:
// https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
type TestResponseRecorder struct {
	*httptest.ResponseRecorder
	closeChannel chan bool
}

// CloseNotify exposes the recorder's close channel as a receive-only channel.
func (rec *TestResponseRecorder) CloseNotify() <-chan bool {
	return rec.closeChannel
}

// CreateTestResponseRecorder returns a TestResponseRecorder backed by a fresh
// httptest.ResponseRecorder and a close channel with a buffer of one.
func CreateTestResponseRecorder() *TestResponseRecorder {
	recorder := &TestResponseRecorder{
		ResponseRecorder: httptest.NewRecorder(),
		closeChannel:     make(chan bool, 1),
	}
	return recorder
}
|
|
|
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
model2:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
for _, modelName := range []string{"model1", "model2"} {
|
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), modelName)
|
|
}
|
|
}
|
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
model2:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2
|
|
groups:
|
|
G1:
|
|
swap: true
|
|
exclusive: false
|
|
members:
|
|
- model1
|
|
G2:
|
|
swap: true
|
|
exclusive: false
|
|
members:
|
|
- model2
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
tests := []string{"model1", "model2"}
|
|
for _, requestedModel := range tests {
|
|
t.Run(requestedModel, func(t *testing.T) {
|
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), requestedModel)
|
|
})
|
|
}
|
|
|
|
// make sure there's two loaded models
|
|
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
|
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
|
}
|
|
|
|
// Test that a persistent group is not affected by the swapping behaviour of
|
|
// other groups.
|
|
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
model2:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2
|
|
groups:
|
|
forever:
|
|
swap: true
|
|
exclusive: false
|
|
persistent: true
|
|
members:
|
|
- model2
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
// make requests to load all models, loading model1 should not affect model2
|
|
tests := []string{"model2", "model1"}
|
|
for _, requestedModel := range tests {
|
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), requestedModel)
|
|
}
|
|
|
|
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
|
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
|
}
|
|
|
|
// When a request for a different model comes in ProxyManager should wait until
|
|
// the first request is complete before swapping. Both requests should complete
|
|
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping slow test")
|
|
}
|
|
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
model2:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2
|
|
model3:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model3
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
results := map[string]string{}
|
|
|
|
var wg sync.WaitGroup
|
|
var mu sync.Mutex
|
|
|
|
for key := range cfg.Models {
|
|
wg.Add(1)
|
|
go func(key string) {
|
|
defer wg.Done()
|
|
|
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
|
}
|
|
|
|
mu.Lock()
|
|
var response map[string]interface{}
|
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
result, ok := response["responseMessage"].(string)
|
|
assert.Equal(t, ok, true)
|
|
results[key] = result
|
|
mu.Unlock()
|
|
}(key)
|
|
|
|
<-time.After(time.Millisecond)
|
|
}
|
|
|
|
wg.Wait()
|
|
assert.Len(t, results, len(cfg.Models))
|
|
|
|
for key, result := range results {
|
|
assert.Equal(t, key, result)
|
|
}
|
|
}
|
|
|
|
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
name: "Model 1"
|
|
description: "Model 1 description is used for testing"
|
|
model2:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2
|
|
name: " "
|
|
description: " "
|
|
model3:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model3
|
|
peers:
|
|
peer1:
|
|
proxy: http://peer1:8080
|
|
models:
|
|
- peer-model-a
|
|
- peer-model-b
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
|
|
// Create a test request
|
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
|
req.Header.Add("Origin", "i-am-the-origin")
|
|
w := CreateTestResponseRecorder()
|
|
|
|
// Call the listModelsHandler
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
// Check the response status code
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
// Check for Access-Control-Allow-Origin
|
|
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
|
|
|
|
// Parse the JSON response
|
|
var response struct {
|
|
Data []map[string]interface{} `json:"data"`
|
|
}
|
|
|
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
|
}
|
|
|
|
// Check the number of models returned (3 local + 2 peer models)
|
|
assert.Len(t, response.Data, 5)
|
|
|
|
// Check the details of each model
|
|
expectedModels := map[string]struct{}{
|
|
"model1": {},
|
|
"model2": {},
|
|
"model3": {},
|
|
"peer-model-a": {},
|
|
"peer-model-b": {},
|
|
}
|
|
|
|
// make all models
|
|
for _, model := range response.Data {
|
|
modelID, ok := model["id"].(string)
|
|
assert.True(t, ok, "model ID should be a string")
|
|
_, exists := expectedModels[modelID]
|
|
assert.True(t, exists, "unexpected model ID: %s", modelID)
|
|
delete(expectedModels, modelID)
|
|
|
|
object, ok := model["object"].(string)
|
|
assert.True(t, ok, "object should be a string")
|
|
assert.Equal(t, "model", object)
|
|
|
|
created, ok := model["created"].(float64)
|
|
assert.True(t, ok, "created should be a number")
|
|
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
|
|
|
|
ownedBy, ok := model["owned_by"].(string)
|
|
assert.True(t, ok, "owned_by should be a string")
|
|
assert.Equal(t, "llama-swap", ownedBy)
|
|
|
|
// check for optional name and description
|
|
if modelID == "model1" {
|
|
name, ok := model["name"].(string)
|
|
assert.True(t, ok, "name should be a string")
|
|
assert.Equal(t, "Model 1", name)
|
|
description, ok := model["description"].(string)
|
|
assert.True(t, ok, "description should be a string")
|
|
assert.Equal(t, "Model 1 description is used for testing", description)
|
|
} else if modelID == "peer-model-a" || modelID == "peer-model-b" {
|
|
// Peer models should have meta.llamaswap.peerID
|
|
meta, exists := model["meta"]
|
|
assert.True(t, exists, "peer model should have meta field")
|
|
metaMap, ok := meta.(map[string]interface{})
|
|
assert.True(t, ok, "meta should be a map")
|
|
llamaswap, exists := metaMap["llamaswap"]
|
|
assert.True(t, exists, "meta should have llamaswap field")
|
|
llamaswapMap, ok := llamaswap.(map[string]interface{})
|
|
assert.True(t, ok, "llamaswap should be a map")
|
|
peerID, exists := llamaswapMap["peerID"]
|
|
assert.True(t, exists, "llamaswap should have peerID field")
|
|
assert.Equal(t, "peer1", peerID)
|
|
} else {
|
|
_, exists := model["name"]
|
|
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
|
_, exists = model["description"]
|
|
assert.False(t, exists, "unexpected description field for model: %s", modelID)
|
|
}
|
|
}
|
|
|
|
// Ensure all expected models were returned
|
|
assert.Empty(t, expectedModels, "not all expected models were returned")
|
|
}
|
|
|
|
func TestProxyManager_ListModelsHandler_WithMetadata(t *testing.T) {
|
|
// Process config through LoadConfigFromReader to apply macro substitution
|
|
configYaml := `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
startPort: 10000
|
|
models:
|
|
model1:
|
|
cmd: /path/to/server -p ${PORT}
|
|
macros:
|
|
PORT_NUM: 10001
|
|
TEMP: 0.7
|
|
NAME: "llama"
|
|
metadata:
|
|
port: ${PORT_NUM}
|
|
temperature: ${TEMP}
|
|
enabled: true
|
|
note: "Running on port ${PORT_NUM}"
|
|
nested:
|
|
value: ${TEMP}
|
|
model2:
|
|
cmd: /path/to/server -p ${PORT}
|
|
`
|
|
processedConfig, err := config.LoadConfigFromReader(strings.NewReader(configYaml))
|
|
assert.NoError(t, err)
|
|
|
|
proxy := New(processedConfig)
|
|
|
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
var response struct {
|
|
Data []map[string]any `json:"data"`
|
|
}
|
|
|
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, response.Data, 2)
|
|
|
|
// Find model1 and model2 in response
|
|
var model1Data, model2Data map[string]any
|
|
for _, model := range response.Data {
|
|
if model["id"] == "model1" {
|
|
model1Data = model
|
|
} else if model["id"] == "model2" {
|
|
model2Data = model
|
|
}
|
|
}
|
|
|
|
// Verify model1 has llamaswap_meta
|
|
assert.NotNil(t, model1Data)
|
|
meta, exists := model1Data["meta"]
|
|
if !assert.True(t, exists, "model1 should have meta key") {
|
|
t.FailNow()
|
|
}
|
|
|
|
metaMap := meta.(map[string]any)
|
|
|
|
lsmeta, exists := metaMap["llamaswap"]
|
|
if !assert.True(t, exists, "model1 should have meta.llamaswap key") {
|
|
t.FailNow()
|
|
}
|
|
|
|
lsmetamap := lsmeta.(map[string]any)
|
|
|
|
// Verify type preservation
|
|
assert.Equal(t, float64(10001), lsmetamap["port"]) // JSON numbers are float64
|
|
assert.Equal(t, 0.7, lsmetamap["temperature"])
|
|
assert.Equal(t, true, lsmetamap["enabled"])
|
|
// Verify string interpolation
|
|
assert.Equal(t, "Running on port 10001", lsmetamap["note"])
|
|
// Verify nested structure
|
|
nested := lsmetamap["nested"].(map[string]any)
|
|
assert.Equal(t, 0.7, nested["value"])
|
|
|
|
// Verify model2 does NOT have llamaswap_meta
|
|
assert.NotNil(t, model2Data)
|
|
_, exists = model2Data["llamaswap_meta"]
|
|
assert.False(t, exists, "model2 should not have llamaswap_meta")
|
|
}
|
|
|
|
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
|
// Intentionally add models in non-sorted order and with an unlisted model
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
zeta:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond zeta
|
|
alpha:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond alpha
|
|
beta:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond beta
|
|
hidden:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond hidden
|
|
unlisted: true
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
|
|
// Request models list
|
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
var response struct {
|
|
Data []map[string]interface{} `json:"data"`
|
|
}
|
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
|
}
|
|
|
|
// We expect only the listed models in sorted order by id
|
|
expectedOrder := []string{"alpha", "beta", "zeta"}
|
|
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
|
|
got := make([]string, 0, len(response.Data))
|
|
for _, m := range response.Data {
|
|
id, _ := m["id"].(string)
|
|
got = append(got, id)
|
|
}
|
|
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
|
|
}
|
|
}
|
|
|
|
func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
|
|
// Configure alias
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
includeAliasesInList: true
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
name: "Model 1"
|
|
aliases:
|
|
- alias1
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
|
|
// Request models list
|
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
var response struct {
|
|
Data []map[string]interface{} `json:"data"`
|
|
}
|
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
|
}
|
|
|
|
// We expect both base id and alias
|
|
var model1Data, alias1Data map[string]any
|
|
for _, model := range response.Data {
|
|
if model["id"] == "model1" {
|
|
model1Data = model
|
|
} else if model["id"] == "alias1" {
|
|
alias1Data = model
|
|
}
|
|
}
|
|
|
|
// Verify model1 has name
|
|
assert.NotNil(t, model1Data)
|
|
_, exists := model1Data["name"]
|
|
if !assert.True(t, exists, "model1 should have name key") {
|
|
t.FailNow()
|
|
}
|
|
name1, ok := model1Data["name"].(string)
|
|
assert.True(t, ok, "name1 should be a string")
|
|
|
|
// Verify alias1 has name
|
|
assert.NotNil(t, alias1Data)
|
|
_, exists = alias1Data["name"]
|
|
if !assert.True(t, exists, "alias1 should have name key") {
|
|
t.FailNow()
|
|
}
|
|
name2, ok := alias1Data["name"].(string)
|
|
assert.True(t, ok, "name2 should be a string")
|
|
|
|
// Name keys should match
|
|
assert.Equal(t, name1, name2)
|
|
}
|
|
|
|
func TestProxyManager_Shutdown(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping slow test")
|
|
}
|
|
|
|
// make broken model configurations
|
|
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
|
model1Config.Proxy = "http://localhost:10001/"
|
|
|
|
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
|
|
model2Config.Proxy = "http://localhost:10002/"
|
|
|
|
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
|
model3Config.Proxy = "http://localhost:10003/"
|
|
|
|
cfg := config.AddDefaultGroupToConfig(config.Config{
|
|
HealthCheckTimeout: 15,
|
|
Models: map[string]config.ModelConfig{
|
|
"model1": model1Config,
|
|
"model2": model2Config,
|
|
"model3": model3Config,
|
|
},
|
|
LogLevel: "error",
|
|
Groups: map[string]config.GroupConfig{
|
|
"test": {
|
|
Swap: false,
|
|
Members: []string{"model1", "model2", "model3"},
|
|
},
|
|
},
|
|
})
|
|
|
|
proxy := New(cfg)
|
|
|
|
// Start all the processes
|
|
var wg sync.WaitGroup
|
|
for _, modelName := range []string{"model1", "model2", "model3"} {
|
|
wg.Add(1)
|
|
go func(modelName string) {
|
|
defer wg.Done()
|
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
|
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
|
}(modelName)
|
|
}
|
|
|
|
go func() {
|
|
<-time.After(time.Second)
|
|
proxy.Shutdown()
|
|
}()
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestProxyManager_Unload(t *testing.T) {
|
|
conf := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
proxy := New(conf)
|
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
|
req = httptest.NewRequest("GET", "/unload", nil)
|
|
w = CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, w.Body.String(), "OK")
|
|
|
|
select {
|
|
case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
|
|
// good
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for model1 to stop")
|
|
}
|
|
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
|
}
|
|
|
|
func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
|
const testGroupId = "testGroup"
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
model2:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2
|
|
groups:
|
|
testGroup:
|
|
swap: false
|
|
members:
|
|
- model1
|
|
- model2
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
|
|
// start both model
|
|
for _, modelName := range []string{"model1", "model2"} {
|
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
}
|
|
|
|
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState())
|
|
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
|
|
|
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
if !assert.Equal(t, w.Body.String(), "OK") {
|
|
t.FailNow()
|
|
}
|
|
|
|
select {
|
|
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
|
|
// good
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for model1 to stop")
|
|
}
|
|
|
|
assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
|
|
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)
|
|
}
|
|
|
|
// Test issue #61 `Listing the current list of models and the loaded model.`
|
|
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|
// Shared configuration
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: warn
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
model2:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2
|
|
`)
|
|
|
|
// Define a helper struct to parse the JSON response.
|
|
type RunningResponse struct {
|
|
Running []struct {
|
|
Model string `json:"model"`
|
|
State string `json:"state"`
|
|
Cmd string `json:"cmd"`
|
|
Proxy string `json:"proxy"`
|
|
TTL int `json:"ttl"`
|
|
Name string `json:"name"`
|
|
Description string `json:"description"`
|
|
} `json:"running"`
|
|
}
|
|
|
|
// Create proxy once for all tests
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
t.Run("no models loaded", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/running", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
var response RunningResponse
|
|
|
|
// Check if this is a valid JSON object.
|
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
|
|
// We should have an empty running array here.
|
|
assert.Empty(t, response.Running, "expected no running models")
|
|
})
|
|
|
|
t.Run("single model loaded", func(t *testing.T) {
|
|
// Load just a model.
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
// Simulate browser call for the `/running` endpoint.
|
|
req = httptest.NewRequest("GET", "/running", nil)
|
|
w = CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
var response RunningResponse
|
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
|
|
// Check if we have a single array element.
|
|
assert.Len(t, response.Running, 1)
|
|
|
|
// Is this the right model?
|
|
assert.Equal(t, "model1", response.Running[0].Model)
|
|
|
|
// Is the model loaded?
|
|
assert.Equal(t, "ready", response.Running[0].State)
|
|
|
|
// Verify extended fields are present
|
|
assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated")
|
|
assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated")
|
|
assert.Equal(t, 0, response.Running[0].TTL, "ttl should default to globalTTL (0)")
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
TheExpectedModel:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
// Create a buffer with multipart form data
|
|
var b bytes.Buffer
|
|
w := multipart.NewWriter(&b)
|
|
|
|
// Add the model field
|
|
fw, err := w.CreateFormField("model")
|
|
assert.NoError(t, err)
|
|
_, err = fw.Write([]byte("TheExpectedModel"))
|
|
assert.NoError(t, err)
|
|
|
|
// Add a file field
|
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
|
assert.NoError(t, err)
|
|
// Generate random content length between 10 and 20
|
|
contentLength := rand.Intn(11) + 10 // 10 to 20
|
|
content := make([]byte, contentLength)
|
|
_, err = fw.Write(content)
|
|
assert.NoError(t, err)
|
|
w.Close()
|
|
|
|
// Create the request with the multipart form data
|
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
|
rec := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(rec, req)
|
|
|
|
// Verify the response
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
var response map[string]string
|
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "TheExpectedModel", response["model"])
|
|
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
|
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
|
|
}
|
|
|
|
// Test useModelName in configuration sends overrides what is sent to upstream
|
|
func TestProxyManager_UseModelName(t *testing.T) {
|
|
upstreamModelName := "upstreamModel"
|
|
|
|
conf := testConfigFromYAML(t, fmt.Sprintf(`
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond %s
|
|
useModelName: %s
|
|
`, upstreamModelName, upstreamModelName))
|
|
|
|
proxy := New(conf)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
requestedModel := "model1"
|
|
|
|
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), upstreamModelName)
|
|
|
|
// make sure the content length was set correctly
|
|
// simple-responder will return the content length it got in the response
|
|
body := w.Body.Bytes()
|
|
contentLength := int(gjson.GetBytes(body, "h_content_length").Int())
|
|
assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength)
|
|
})
|
|
|
|
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
|
|
// Create a buffer with multipart form data
|
|
var b bytes.Buffer
|
|
w := multipart.NewWriter(&b)
|
|
|
|
// Add the model field
|
|
fw, err := w.CreateFormField("model")
|
|
assert.NoError(t, err)
|
|
_, err = fw.Write([]byte(requestedModel))
|
|
assert.NoError(t, err)
|
|
|
|
// Add a file field
|
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
|
assert.NoError(t, err)
|
|
_, err = fw.Write([]byte("test"))
|
|
assert.NoError(t, err)
|
|
w.Close()
|
|
|
|
// Create the request with the multipart form data
|
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
|
rec := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(rec, req)
|
|
|
|
// Verify the response
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
var response map[string]string
|
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, upstreamModelName, response["model"])
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_AudioVoicesGETHandler(t *testing.T) {
|
|
conf := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
proxy := New(conf)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
t.Run("successful GET with model query param", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/v1/audio/voices?model=model1", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "voice1")
|
|
})
|
|
|
|
t.Run("missing model query param returns 400", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/v1/audio/voices", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
|
|
})
|
|
|
|
t.Run("unknown model returns 400", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/v1/audio/voices?model=nonexistent", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
tests := []struct {
|
|
name string
|
|
method string
|
|
requestHeaders map[string]string
|
|
expectedStatus int
|
|
expectedHeaders map[string]string
|
|
}{
|
|
{
|
|
name: "OPTIONS with no headers",
|
|
method: "OPTIONS",
|
|
expectedStatus: http.StatusNoContent,
|
|
expectedHeaders: map[string]string{
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
|
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
|
|
},
|
|
},
|
|
{
|
|
name: "OPTIONS with specific headers",
|
|
method: "OPTIONS",
|
|
requestHeaders: map[string]string{
|
|
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
|
|
},
|
|
expectedStatus: http.StatusNoContent,
|
|
expectedHeaders: map[string]string{
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
|
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
|
|
},
|
|
},
|
|
{
|
|
name: "Non-OPTIONS request",
|
|
method: "GET",
|
|
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
|
for k, v := range tt.requestHeaders {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tt.expectedStatus, w.Code)
|
|
|
|
for header, expectedValue := range tt.expectedHeaders {
|
|
assert.Equal(t, expectedValue, w.Header().Get(header))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProxyManager_Upstream(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
aliases: [model-alias]
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
t.Run("main model name", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
|
rec := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "model1", rec.Body.String())
|
|
})
|
|
|
|
t.Run("model alias", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
|
rec := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "model1", rec.Body.String())
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_ChatContentLength(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
var response map[string]interface{}
|
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
assert.Equal(t, "81", response["h_content_length"])
|
|
assert.Equal(t, "model1", response["responseMessage"])
|
|
}
|
|
|
|
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
filters:
|
|
stripParams: "temperature, model, stream"
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
var response map[string]interface{}
|
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
|
|
// `temperature` and `stream` are gone but model remains
|
|
assert.Equal(t, `{"model":"model1", "x_param":"123", "y_param":"abc"}`, response["request_body"])
|
|
|
|
// assert.Nil(t, response["temperature"])
|
|
// assert.Equal(t, "123", response["x_param"])
|
|
// assert.Equal(t, "abc", response["y_param"])
|
|
// t.Logf("%v", response)
|
|
}
|
|
|
|
func TestProxyManager_FiltersSetParamsByID(t *testing.T) {
|
|
// no explicit aliases — setParamsByID keys are auto-registered as aliases
|
|
cfg := testConfigFromYAML(t, `
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
proxy: "http://127.0.0.1:${PORT}"
|
|
filters:
|
|
setParams:
|
|
reasoning_effort: medium
|
|
setParamsByID:
|
|
"${MODEL_ID}:high":
|
|
reasoning_effort: high
|
|
"${MODEL_ID}:low":
|
|
reasoning_effort: low
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
tests := []struct {
|
|
requestedModel string
|
|
wantEffort string
|
|
}{
|
|
// setParams applies, no setParamsByID match
|
|
{requestedModel: "model1", wantEffort: "medium"},
|
|
// setParamsByID overrides setParams
|
|
{requestedModel: "model1:high", wantEffort: "high"},
|
|
{requestedModel: "model1:low", wantEffort: "low"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.requestedModel, func(t *testing.T) {
|
|
reqBody := fmt.Sprintf(`{"model":%q}`, tt.requestedModel)
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
var response map[string]interface{}
|
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
|
|
requestBody, _ := response["request_body"].(string)
|
|
gotEffort := gjson.Get(requestBody, "reasoning_effort").String()
|
|
assert.Equal(t, tt.wantEffort, gotEffort, "reasoning_effort mismatch for model %s", tt.requestedModel)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
req := httptest.NewRequest("GET", "/health", nil)
|
|
rec := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "OK", rec.Body.String())
|
|
}
|
|
|
|
// Ensure the custom llama-server /completion endpoint proxies correctly
|
|
func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "model1")
|
|
}
|
|
|
|
func TestProxyManager_StartupHooks(t *testing.T) {
|
|
|
|
cfg := testConfigFromYAML(t, `
|
|
logLevel: error
|
|
hooks:
|
|
on_startup:
|
|
preload:
|
|
- model1
|
|
- model2
|
|
groups:
|
|
preloadTestGroup:
|
|
swap: false
|
|
members:
|
|
- model1
|
|
- model2
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
model2:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model2
|
|
`)
|
|
|
|
preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events
|
|
|
|
unsub := event.On(func(e ModelPreloadedEvent) {
|
|
preloadChan <- e
|
|
})
|
|
|
|
defer unsub()
|
|
|
|
// Create the proxy which should trigger preloading
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
for i := 0; i < 2; i++ {
|
|
select {
|
|
case <-preloadChan:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timed out waiting for models to preload")
|
|
}
|
|
}
|
|
// make sure they are both loaded
|
|
_, foundGroup := proxy.processGroups["preloadTestGroup"]
|
|
if !assert.True(t, foundGroup, "preloadTestGroup should exist") {
|
|
return
|
|
}
|
|
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model1"].CurrentState())
|
|
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model2"].CurrentState())
|
|
}
|
|
|
|
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
author/model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond author/model
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
endpoints := []string{
|
|
"/api/events",
|
|
"/logs/stream",
|
|
"/logs/stream/proxy",
|
|
"/logs/stream/upstream",
|
|
"/logs/stream/author/model",
|
|
}
|
|
|
|
for _, endpoint := range endpoints {
|
|
t.Run(endpoint, func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
req := httptest.NewRequest("GET", endpoint, nil)
|
|
req = req.WithContext(ctx)
|
|
rec := CreateTestResponseRecorder()
|
|
|
|
// Run handler in goroutine and wait for context timeout
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
proxy.ServeHTTP(rec, req)
|
|
}()
|
|
|
|
// Wait for either the handler to complete or context to timeout
|
|
<-ctx.Done()
|
|
|
|
// At this point, the handler has either finished or been cancelled
|
|
// Wait for the goroutine to fully exit before reading
|
|
<-done
|
|
|
|
// Now it's safe to read from rec - no more concurrent writes
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
streaming-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond streaming-model
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
// Make a streaming request
|
|
reqBody := `{"model":"streaming-model"}`
|
|
// simple-responder will return text/event-stream when stream=true is in the query
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
|
rec := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
|
assert.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
|
}
|
|
|
|
func TestProxyManager_ApiGetVersion(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
// Version test map
|
|
versionTest := map[string]string{
|
|
"build_date": "1970-01-01T00:00:00Z",
|
|
"commit": "cc915ddb6f04a42d9cd1f524e1d46ec6ed069fdc",
|
|
"version": "v001",
|
|
}
|
|
|
|
proxy := New(cfg)
|
|
proxy.SetVersion(versionTest["build_date"], versionTest["commit"], versionTest["version"])
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
req := httptest.NewRequest("GET", "/api/version", nil)
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
|
|
// Ensure json response
|
|
assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type"))
|
|
|
|
// Check for attributes
|
|
response := map[string]string{}
|
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
for key, value := range versionTest {
|
|
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
|
|
}
|
|
}
|
|
|
|
func TestProxyManager_APIKeyAuth(t *testing.T) {
|
|
testConfig := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
apiKeys:
|
|
- valid-key-1
|
|
- valid-key-2
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
proxy := New(testConfig)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
t.Run("valid key in x-api-key header", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
req.Header.Set("x-api-key", "valid-key-1")
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
})
|
|
|
|
t.Run("valid key in Authorization Bearer header", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
req.Header.Set("Authorization", "Bearer valid-key-2")
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
})
|
|
|
|
t.Run("both headers with matching keys", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
req.Header.Set("x-api-key", "valid-key-1")
|
|
req.Header.Set("Authorization", "Bearer valid-key-1")
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
})
|
|
|
|
t.Run("invalid key returns 401", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
req.Header.Set("x-api-key", "invalid-key")
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
assert.Contains(t, w.Body.String(), "unauthorized")
|
|
})
|
|
|
|
t.Run("missing key returns 401", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
})
|
|
|
|
t.Run("valid key in Basic Auth header", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
// Basic Auth: base64("anyuser:valid-key-1")
|
|
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
|
|
req.Header.Set("Authorization", "Basic "+credentials)
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
})
|
|
|
|
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
|
|
req.Header.Set("Authorization", "Basic "+credentials)
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
assert.Contains(t, w.Body.String(), "unauthorized")
|
|
})
|
|
|
|
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
req.Header.Set("x-api-key", "valid-key-1")
|
|
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
|
|
req.Header.Set("Authorization", "Basic "+credentials)
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
})
|
|
|
|
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
|
|
// Config without RequiredAPIKeys - auth should be disabled
|
|
testConfig := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
proxy := New(testConfig)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
t.Run("requests pass without API key when not configured", func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
})
|
|
}
|
|
|
|
// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration
|
|
// in proxyInferenceHandler for issue #433
|
|
func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) {
|
|
t.Run("requests to peer models are proxied", func(t *testing.T) {
|
|
// Create a test server to act as the peer
|
|
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`))
|
|
}))
|
|
defer peerServer.Close()
|
|
|
|
testConfig := testConfigFromYAML(t, fmt.Sprintf(`
|
|
logLevel: error
|
|
peers:
|
|
test-peer:
|
|
proxy: %s
|
|
models:
|
|
- peer-model
|
|
models:
|
|
local-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model
|
|
`, peerServer.URL))
|
|
|
|
proxy := New(testConfig)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
reqBody := `{"model":"peer-model"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "from-peer")
|
|
})
|
|
|
|
t.Run("local models take precedence over peer models", func(t *testing.T) {
|
|
// Create a test server to act as the peer - should NOT be called
|
|
peerCalled := false
|
|
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
peerCalled = true
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"response":"from-peer"}`))
|
|
}))
|
|
defer peerServer.Close()
|
|
|
|
testConfig := testConfigFromYAML(t, fmt.Sprintf(`
|
|
logLevel: error
|
|
peers:
|
|
test-peer:
|
|
proxy: %s
|
|
models:
|
|
- shared-model
|
|
models:
|
|
shared-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-response
|
|
`, peerServer.URL))
|
|
|
|
proxy := New(testConfig)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
injectTestHandlers(proxy, map[string]string{"shared-model": "local-response"})
|
|
|
|
reqBody := `{"model":"shared-model"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "local-response")
|
|
assert.False(t, peerCalled, "peer should not be called when local model exists")
|
|
})
|
|
|
|
t.Run("unknown model returns error", func(t *testing.T) {
|
|
// Create a test server to act as the peer
|
|
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer peerServer.Close()
|
|
|
|
testConfig := testConfigFromYAML(t, fmt.Sprintf(`
|
|
logLevel: error
|
|
peers:
|
|
test-peer:
|
|
proxy: %s
|
|
models:
|
|
- peer-model
|
|
models:
|
|
local-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model
|
|
`, peerServer.URL))
|
|
|
|
proxy := New(testConfig)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
reqBody := `{"model":"unknown-model"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
|
})
|
|
|
|
t.Run("peer API key is injected into request", func(t *testing.T) {
|
|
var receivedAuthHeader string
|
|
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
receivedAuthHeader = r.Header.Get("Authorization")
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"response":"ok"}`))
|
|
}))
|
|
defer peerServer.Close()
|
|
|
|
testConfig := testConfigFromYAML(t, fmt.Sprintf(`
|
|
logLevel: error
|
|
peers:
|
|
test-peer:
|
|
proxy: %s
|
|
apiKey: secret-peer-key
|
|
models:
|
|
- peer-model
|
|
models:
|
|
local-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model
|
|
`, peerServer.URL))
|
|
|
|
proxy := New(testConfig)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
reqBody := `{"model":"peer-model"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader)
|
|
})
|
|
|
|
t.Run("no peers configured - unknown model returns error", func(t *testing.T) {
|
|
testConfig := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
local-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model
|
|
`)
|
|
|
|
proxy := New(testConfig)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
// peerProxy exists but has no peer models configured
|
|
assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model"))
|
|
|
|
reqBody := `{"model":"unknown-model"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
|
})
|
|
|
|
t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) {
|
|
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("data: test\n\n"))
|
|
}))
|
|
defer peerServer.Close()
|
|
|
|
testConfig := testConfigFromYAML(t, fmt.Sprintf(`
|
|
logLevel: error
|
|
peers:
|
|
test-peer:
|
|
proxy: %s
|
|
models:
|
|
- peer-model
|
|
models:
|
|
local-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model
|
|
`, peerServer.URL))
|
|
|
|
proxy := New(testConfig)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
reqBody := `{"model":"peer-model"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_SdApiTxt2ImgRouting(t *testing.T) {
|
|
conf := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
sd-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond sd-model
|
|
`)
|
|
|
|
proxy := New(conf)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
t.Run("successful txt2img with model", func(t *testing.T) {
|
|
reqBody := `{"model":"sd-model","prompt":"a cat"}`
|
|
req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "sd-model")
|
|
})
|
|
|
|
t.Run("successful img2img with model", func(t *testing.T) {
|
|
reqBody := `{"model":"sd-model","prompt":"a cat","init_images":[]}`
|
|
req := httptest.NewRequest("POST", "/sdapi/v1/img2img", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "sd-model")
|
|
})
|
|
|
|
t.Run("missing model returns 400", func(t *testing.T) {
|
|
reqBody := `{"prompt":"a cat"}`
|
|
req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Contains(t, w.Body.String(), "missing or invalid 'model' key")
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_SdApiGetLoras(t *testing.T) {
|
|
conf := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
sd-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond sd-model
|
|
`)
|
|
|
|
proxy := New(conf)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
t.Run("successful GET loras with model query param", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=sd-model", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
})
|
|
|
|
t.Run("missing model query param returns 400", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/sdapi/v1/loras", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
|
|
})
|
|
|
|
t.Run("unknown model returns 400", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=nonexistent", nil)
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_AudioTranscriptionCapture(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
captureBuffer: 5
|
|
models:
|
|
TheExpectedModel:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond TheExpectedModel
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
var b bytes.Buffer
|
|
w := multipart.NewWriter(&b)
|
|
|
|
fw, err := w.CreateFormField("model")
|
|
assert.NoError(t, err)
|
|
_, err = fw.Write([]byte("TheExpectedModel"))
|
|
assert.NoError(t, err)
|
|
|
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
|
assert.NoError(t, err)
|
|
_, err = fw.Write([]byte("test audio content"))
|
|
assert.NoError(t, err)
|
|
w.Close()
|
|
|
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
|
req.Header.Set("Authorization", "Bearer mysecret")
|
|
req.Header.Set("X-Custom-Req", "req-value")
|
|
rec := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
// Verify capture exists
|
|
metrics := proxy.metricsMonitor.getMetrics()
|
|
assert.Equal(t, 1, len(metrics))
|
|
assert.True(t, metrics[0].HasCapture)
|
|
|
|
capture := proxy.metricsMonitor.getCaptureByID(metrics[0].ID)
|
|
assert.NotNil(t, capture)
|
|
|
|
// Should capture request headers (sensitive ones redacted)
|
|
assert.NotEmpty(t, capture.ReqHeaders)
|
|
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
|
assert.Equal(t, "req-value", capture.ReqHeaders["X-Custom-Req"])
|
|
|
|
// Should capture response headers
|
|
assert.NotNil(t, capture.RespHeaders)
|
|
|
|
// Should NOT capture request bodies but get response bodies (text
|
|
assert.Nil(t, capture.ReqBody)
|
|
assert.NotNil(t, capture.RespBody)
|
|
}
|
|
|
|
func TestProxyManager_VersionlessEndpoints_LocalModel(t *testing.T) {
|
|
cfg := testConfigFromYAML(t, `
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
models:
|
|
model1:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
|
`)
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
injectTestHandlers(proxy, nil)
|
|
|
|
endpoints := []string{
|
|
"/v/chat/completions",
|
|
"/v/responses",
|
|
"/v/completions",
|
|
"/v/embeddings",
|
|
"/v/rerank",
|
|
"/v/reranking",
|
|
}
|
|
|
|
for _, endpoint := range endpoints {
|
|
t.Run(endpoint, func(t *testing.T) {
|
|
reqBody := `{"model":"model1"}`
|
|
req := httptest.NewRequest("POST", endpoint, bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "model1")
|
|
})
|
|
}
|
|
|
|
t.Run("/v/messages", func(t *testing.T) {
|
|
reqBody := `{"model":"model1","messages":[{"role":"user","content":"hi"}]}`
|
|
req := httptest.NewRequest("POST", "/v/messages", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "model1")
|
|
})
|
|
}
|
|
|
|
func TestProxyManager_VersionlessEndpoints_PeerModel(t *testing.T) {
|
|
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
fmt.Fprintf(w, `{"endpoint":"%s","model":"peer-model"}`, r.URL.Path)
|
|
}))
|
|
defer peerServer.Close()
|
|
|
|
cfg := testConfigFromYAML(t, fmt.Sprintf(`
|
|
healthCheckTimeout: 15
|
|
logLevel: error
|
|
peers:
|
|
test-peer:
|
|
proxy: %s
|
|
models:
|
|
- peer-model
|
|
models:
|
|
local-model:
|
|
cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model
|
|
`, peerServer.URL))
|
|
|
|
proxy := New(cfg)
|
|
defer proxy.StopProcesses(StopImmediately)
|
|
|
|
endpoints := []struct {
|
|
path string
|
|
wantSuffix string
|
|
}{
|
|
{"/v/chat/completions", "/chat/completions"},
|
|
{"/v/responses", "/responses"},
|
|
{"/v/completions", "/completions"},
|
|
{"/v/embeddings", "/embeddings"},
|
|
{"/v/rerank", "/rerank"},
|
|
{"/v/reranking", "/reranking"},
|
|
}
|
|
|
|
for _, ep := range endpoints {
|
|
t.Run(ep.path, func(t *testing.T) {
|
|
reqBody := `{"model":"peer-model"}`
|
|
req := httptest.NewRequest("POST", ep.path, bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), ep.wantSuffix)
|
|
})
|
|
}
|
|
|
|
t.Run("/v/messages", func(t *testing.T) {
|
|
reqBody := `{"model":"peer-model","messages":[{"role":"user","content":"hi"}]}`
|
|
req := httptest.NewRequest("POST", "/v/messages", bytes.NewBufferString(reqBody))
|
|
w := CreateTestResponseRecorder()
|
|
proxy.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "/messages")
|
|
})
|
|
}
|