proxy: add versionless API endpoint (#733)
Add versionless endpoints under v/ to support upstream peers that do not use the v1/ prefix. Fixes #728.
This commit is contained in:
committed by
GitHub
parent
11b7913287
commit
e261745c66
+31
-20
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -125,6 +126,22 @@ func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
|
||||
// newTestHandler returns an http.Handler that mimics simple-responder's API.
|
||||
// It supports the endpoints that routing tests depend on, without launching
|
||||
// any subprocess or binding any port.
|
||||
func respondJSON(w http.ResponseWriter, respond string, bodyBytes []byte) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"h_content_length": strconv.Itoa(len(bodyBytes)),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func newTestHandler(respond string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
@@ -170,19 +187,7 @@ func newTestHandler(respond string) http.Handler {
|
||||
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"h_content_length": r.Header.Get("Content-Length"),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -198,15 +203,21 @@ func newTestHandler(respond string) http.Handler {
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
})
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
})
|
||||
|
||||
for _, path := range []string{
|
||||
"/chat/completions", "/completions",
|
||||
"/responses", "/messages", "/messages/count_tokens",
|
||||
"/embeddings", "/rerank", "/reranking",
|
||||
} {
|
||||
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
respondJSON(w, respond, bodyBytes)
|
||||
})
|
||||
}
|
||||
|
||||
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
|
||||
@@ -351,6 +351,16 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
|
||||
// Unversioned API endpoints, see issue #728
|
||||
pm.ginEngine.POST("/v/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
pm.ginEngine.POST("/v/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
|
||||
// llama-server's /infill endpoint for code infilling
|
||||
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
|
||||
|
||||
@@ -860,6 +870,11 @@ func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context)
|
||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
c.Request.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
// issue #728 support versionless API requests
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v/") {
|
||||
c.Request.URL.Path = strings.TrimPrefix(c.Request.URL.Path, "/v")
|
||||
}
|
||||
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
|
||||
@@ -1779,3 +1779,103 @@ models:
|
||||
assert.Nil(t, capture.ReqBody)
|
||||
assert.NotNil(t, capture.RespBody)
|
||||
}
|
||||
|
||||
func TestProxyManager_VersionlessEndpoints_LocalModel(t *testing.T) {
|
||||
cfg := testConfigFromYAML(t, `
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
models:
|
||||
model1:
|
||||
cmd: {{RESPONDER}} --port ${PORT} --silent --respond model1
|
||||
`)
|
||||
|
||||
proxy := New(cfg)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
injectTestHandlers(proxy, nil)
|
||||
|
||||
endpoints := []string{
|
||||
"/v/chat/completions",
|
||||
"/v/responses",
|
||||
"/v/completions",
|
||||
"/v/embeddings",
|
||||
"/v/rerank",
|
||||
"/v/reranking",
|
||||
}
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
t.Run(endpoint, func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", endpoint, bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "model1")
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("/v/messages", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1","messages":[{"role":"user","content":"hi"}]}`
|
||||
req := httptest.NewRequest("POST", "/v/messages", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "model1")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_VersionlessEndpoints_PeerModel(t *testing.T) {
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"endpoint":"%s","model":"peer-model"}`, r.URL.Path)
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
cfg := testConfigFromYAML(t, fmt.Sprintf(`
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: {{RESPONDER}} --port ${PORT} --silent --respond local-model
|
||||
`, peerServer.URL))
|
||||
|
||||
proxy := New(cfg)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
endpoints := []struct {
|
||||
path string
|
||||
wantSuffix string
|
||||
}{
|
||||
{"/v/chat/completions", "/chat/completions"},
|
||||
{"/v/responses", "/responses"},
|
||||
{"/v/completions", "/completions"},
|
||||
{"/v/embeddings", "/embeddings"},
|
||||
{"/v/rerank", "/rerank"},
|
||||
{"/v/reranking", "/reranking"},
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
t.Run(ep.path, func(t *testing.T) {
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", ep.path, bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), ep.wantSuffix)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("/v/messages", func(t *testing.T) {
|
||||
reqBody := `{"model":"peer-model","messages":[{"role":"user","content":"hi"}]}`
|
||||
req := httptest.NewRequest("POST", "/v/messages", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "/messages")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user