Compare commits

...

5 Commits

Author SHA1 Message Date
Benson Wong 831a90d3b0 Add different timeout scenarios to Process.checkHealthEndpoint #276 (#278)
- add a TCP connection timeout of 500ms
- increase HTTP client timeout to 5000ms

In this new behaviour the upstream has 500ms to accept a tcp connection
and 5000ms to respond to the HTTP request.
2025-08-28 22:03:14 -07:00
Yandrik 977f1856bb add /completion endpoint (#275)
* feat: add /completion endpoint
* chore: reformat using gofmt
2025-08-28 21:41:02 -07:00
Benson Wong 52b329f7bc Fix #277 race condition in ProcessGroup.ProxyRequest when swap=true 2025-08-28 21:38:40 -07:00
Benson Wong 57803fd3aa Support llama-server's /infill endpoint (#272)
Add support for llama-server's /infill endpoint and metrics gathering on the Activities page.
2025-08-27 08:36:05 -07:00
Benson Wong c55d0cc842 Add docs for model.concurrencyLimit #263 [skip ci] 2025-08-22 16:08:37 -07:00
9 changed files with 143 additions and 43 deletions
+4 -1
View File
@@ -18,9 +18,12 @@ Written in golang, it is very easy to install (single binary with no dependencie
- `v1/completions`
- `v1/chat/completions`
- `v1/embeddings`
- `v1/rerank`, `v1/reranking`, `rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- ✅ llama-server (llama.cpp) supported endpoints:
- `v1/rerank`, `v1/reranking`, `/rerank`
- `/infill` - for code infilling
- `/completion` - for completion endpoint
- ✅ llama-swap custom API endpoints
- `/ui` - web UI
- `/log` - remote log monitoring
+9
View File
@@ -129,6 +129,15 @@ models:
# - recommended to stick to sampling parameters
strip_params: "temperature, top_p, top_k"
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
# - optional, default: 0
# - useful for limiting the number of active parallel requests a model can process
# - must be set per model
# - any number greater than 0 will override the internal default value of 10
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
# - recommended to be omitted and the default used
concurrencyLimit: 0
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
+13
View File
@@ -153,6 +153,19 @@ func main() {
})
// llama-server compatibility: /completion
r.POST("/completion", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
"usage": gin.H{
"completion_tokens": 10,
"prompt_tokens": 25,
"total_tokens": 35,
},
})
})
// issue #41
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
// Parse the multipart form
+28 -22
View File
@@ -5,12 +5,20 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
// isStreaming bool
startTime time.Time
}
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
return func(c *gin.Context) {
@@ -41,49 +49,47 @@ func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
metricsRecorder: &MetricsRecorder{
metricsMonitor: pm.metricsMonitor,
realModelName: realModelName,
isStreaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
startTime: time.Now(),
},
}
c.Writer = writer
c.Next()
rec := writer.metricsRecorder
rec.processBody(writer.body)
}
}
// check for streaming response
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
writer.metricsRecorder.processStreamingResponse(writer.body)
} else {
writer.metricsRecorder.processNonStreamingResponse(writer.body)
}
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
isStreaming bool
startTime time.Time
}
// processBody handles response processing after request completes
func (rec *MetricsRecorder) processBody(body []byte) {
if rec.isStreaming {
rec.processStreamingResponse(body)
} else {
rec.processNonStreamingResponse(body)
}
}
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
usage := jsonData.Get("usage")
if !usage.Exists() {
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return false
}
// default values
outputTokens := int(jsonData.Get("usage.completion_tokens").Int())
inputTokens := int(jsonData.Get("usage.prompt_tokens").Int())
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(rec.startTime).Milliseconds())
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings := jsonData.Get("timings"); timings.Exists() {
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
+12 -1
View File
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os/exec"
@@ -363,8 +364,18 @@ func (p *Process) stopCommand() {
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{
Timeout: 500 * time.Millisecond,
// wait a short time for a tcp connection to be established
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 500 * time.Millisecond,
}).DialContext,
},
// give a long time to respond to the health check endpoint
// after the connection is established. See issue: 276
Timeout: 5000 * time.Millisecond,
}
req, err := http.NewRequest("GET", healthURL, nil)
+10
View File
@@ -60,10 +60,20 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
if pg.swap {
pg.Lock()
if pg.lastUsedProcess != modelID {
// is there something already running?
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
}
// wait for the request to the new model to be fully handled
// and prevent race conditions see issue #277
pg.processes[modelID].ProxyRequest(writer, request)
pg.lastUsedProcess = modelID
// short circuit and exit
pg.Unlock()
return nil
}
pg.Unlock()
}
+34 -16
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -44,32 +45,49 @@ func TestProcessGroup_HasMember(t *testing.T) {
assert.False(t, pg.HasMember("model3"))
}
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
// and multiple requests are made in parallel, only one process is running at a time.
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
// use the same listening so if a model is already running, it will fail
// this is a way to test that swap isolation is working
// properly when there are parallel requests made at the
// same time.
"model1": getTestSimpleResponderConfigPort("model1", 9832),
"model2": getTestSimpleResponderConfigPort("model2", 9832),
"model3": getTestSimpleResponderConfigPort("model3", 9832),
"model4": getTestSimpleResponderConfigPort("model4", 9832),
"model5": getTestSimpleResponderConfigPort("model5", 9832),
},
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2", "model3", "model4", "model5"},
},
},
})
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"}
tests := []string{"model1", "model2", "model3", "model4", "model5"}
var wg sync.WaitGroup
wg.Add(len(tests))
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
go func(modelName string) {
defer wg.Done()
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
// make sure only one process is in the running state
count := 0
for _, process := range pg.processes {
if process.CurrentState() == StateReady {
count++
}
}
assert.Equal(t, 1, count)
})
}(modelName)
}
wg.Wait()
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
+11 -2
View File
@@ -191,11 +191,20 @@ func (pm *ProxyManager) setupGinEngine() {
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
// Support embeddings
// Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
// llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
// llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
// llama-server's /completion endpoint
pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
+22 -1
View File
@@ -42,7 +42,6 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
assert.Contains(t, w.Body.String(), modelName)
}
}
func TestProxyManager_SwapMultiProcess(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
@@ -834,6 +833,28 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
assert.Equal(t, "OK", rec.Body.String())
}
// Ensure the custom llama-server /completion endpoint proxies correctly
func TestProxyManager_CompletionEndpoint(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
func TestProxyManager_StartupHooks(t *testing.T) {
// using real YAML as the configuration has gotten more complex