Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 831a90d3b0 | |||
| 977f1856bb | |||
| 52b329f7bc |
@@ -23,6 +23,7 @@ Written in golang, it is very easy to install (single binary with no dependencie
|
|||||||
- ✅ llama-server (llama.cpp) supported endpoints:
|
- ✅ llama-server (llama.cpp) supported endpoints:
|
||||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||||
- `/infill` - for code infilling
|
- `/infill` - for code infilling
|
||||||
|
- `/completion` - for completion endpoint
|
||||||
- ✅ llama-swap custom API endpoints
|
- ✅ llama-swap custom API endpoints
|
||||||
- `/ui` - web UI
|
- `/ui` - web UI
|
||||||
- `/log` - remote log monitoring
|
- `/log` - remote log monitoring
|
||||||
|
|||||||
@@ -153,6 +153,19 @@ func main() {
|
|||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// llama-server compatibility: /completion
|
||||||
|
r.POST("/completion", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
// issue #41
|
// issue #41
|
||||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||||
// Parse the multipart form
|
// Parse the multipart form
|
||||||
|
|||||||
+12
-1
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@@ -363,8 +364,18 @@ func (p *Process) stopCommand() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Timeout: 500 * time.Millisecond,
|
// wait a short time for a tcp connection to be established
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 500 * time.Millisecond,
|
||||||
|
}).DialContext,
|
||||||
|
},
|
||||||
|
|
||||||
|
// give a long time to respond to the health check endpoint
|
||||||
|
// after the connection is established. See issue: 276
|
||||||
|
Timeout: 5000 * time.Millisecond,
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", healthURL, nil)
|
req, err := http.NewRequest("GET", healthURL, nil)
|
||||||
|
|||||||
@@ -60,10 +60,20 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
|
|||||||
if pg.swap {
|
if pg.swap {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
if pg.lastUsedProcess != modelID {
|
if pg.lastUsedProcess != modelID {
|
||||||
|
|
||||||
|
// is there something already running?
|
||||||
if pg.lastUsedProcess != "" {
|
if pg.lastUsedProcess != "" {
|
||||||
pg.processes[pg.lastUsedProcess].Stop()
|
pg.processes[pg.lastUsedProcess].Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wait for the request to the new model to be fully handled
|
||||||
|
// and prevent race conditions see issue #277
|
||||||
|
pg.processes[modelID].ProxyRequest(writer, request)
|
||||||
pg.lastUsedProcess = modelID
|
pg.lastUsedProcess = modelID
|
||||||
|
|
||||||
|
// short circuit and exit
|
||||||
|
pg.Unlock()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
pg.Unlock()
|
pg.Unlock()
|
||||||
}
|
}
|
||||||
|
|||||||
+34
-16
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -44,32 +45,49 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
assert.False(t, pg.HasMember("model3"))
|
assert.False(t, pg.HasMember("model3"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||||
|
// and multiple requests are made in parallel, only one process is running at a time.
|
||||||
|
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||||
|
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
// use the same listening so if a model is already running, it will fail
|
||||||
|
// this is a way to test that swap isolation is working
|
||||||
|
// properly when there are parallel requests made at the
|
||||||
|
// same time.
|
||||||
|
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||||
|
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||||
|
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||||
|
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||||
|
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Add(len(tests))
|
||||||
for _, modelName := range tests {
|
for _, modelName := range tests {
|
||||||
t.Run(modelName, func(t *testing.T) {
|
go func(modelName string) {
|
||||||
reqBody := `{"x", "y"}`
|
defer wg.Done()
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
}(modelName)
|
||||||
// make sure only one process is in the running state
|
|
||||||
count := 0
|
|
||||||
for _, process := range pg.processes {
|
|
||||||
if process.CurrentState() == StateReady {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert.Equal(t, 1, count)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
|
|||||||
@@ -203,6 +203,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
// llama-server's /infill endpoint for code infilling
|
// llama-server's /infill endpoint for code infilling
|
||||||
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// llama-server's /completion endpoint
|
||||||
|
pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
@@ -834,6 +833,28 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
|
|||||||
assert.Equal(t, "OK", rec.Body.String())
|
assert.Equal(t, "OK", rec.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure the custom llama-server /completion endpoint proxies correctly
|
||||||
|
func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
||||||
|
config := AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "model1")
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxyManager_StartupHooks(t *testing.T) {
|
func TestProxyManager_StartupHooks(t *testing.T) {
|
||||||
|
|
||||||
// using real YAML as the configuration has gotten more complex
|
// using real YAML as the configuration has gotten more complex
|
||||||
|
|||||||
Reference in New Issue
Block a user