Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a15e47922c |
+1
-27
@@ -314,7 +314,7 @@ func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||||
upstreamPath := r.PathValue("upstreamPath")
|
upstreamPath := r.PathValue("upstreamPath")
|
||||||
|
|
||||||
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
searchName, modelID, remainingPath, found := shared.FindModelInPath(s.cfg, "/"+upstreamPath)
|
||||||
if !found {
|
if !found {
|
||||||
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||||
return
|
return
|
||||||
@@ -349,29 +349,3 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// findModelInPath walks a slash-separated path, building up segments until one
|
|
||||||
// matches a configured model. This resolves model names that contain slashes
|
|
||||||
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
|
||||||
// remaining path, and whether a match was found.
|
|
||||||
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
|
||||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
|
||||||
name := ""
|
|
||||||
|
|
||||||
for i, part := range parts {
|
|
||||||
if part == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if name == "" {
|
|
||||||
name = part
|
|
||||||
} else {
|
|
||||||
name = name + "/" + part
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelID, ok := cfg.RealModelName(name); ok {
|
|
||||||
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", "", "", false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,11 +2,15 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_HandleListModels(t *testing.T) {
|
func TestServer_HandleListModels(t *testing.T) {
|
||||||
@@ -78,6 +82,7 @@ func TestServer_HandleListModels_Aliases(t *testing.T) {
|
|||||||
|
|
||||||
func TestServer_FindModelInPath(t *testing.T) {
|
func TestServer_FindModelInPath(t *testing.T) {
|
||||||
cfg := config.Config{Models: map[string]config.ModelConfig{
|
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||||
|
"author": {},
|
||||||
"author/model": {},
|
"author/model": {},
|
||||||
"simple": {},
|
"simple": {},
|
||||||
}}
|
}}
|
||||||
@@ -91,13 +96,14 @@ func TestServer_FindModelInPath(t *testing.T) {
|
|||||||
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
||||||
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
||||||
{"/author/model", "author/model", "/", true},
|
{"/author/model", "author/model", "/", true},
|
||||||
|
{"/author/v1/chat", "author", "/v1/chat", true},
|
||||||
{"/missing/v1", "", "", false},
|
{"/missing/v1", "", "", false},
|
||||||
{"/", "", "", false},
|
{"/", "", "", false},
|
||||||
}
|
}
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
name, _, rem, found := findModelInPath(cfg, c.path)
|
name, _, rem, found := shared.FindModelInPath(cfg, c.path)
|
||||||
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
||||||
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
t.Errorf("FindModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
||||||
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -133,6 +139,80 @@ func TestServer_HandleUpstream(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func upstreamMetricsServer(response string) *Server {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
proxylog := logmon.NewWriter(io.Discard)
|
||||||
|
s := &Server{
|
||||||
|
cfg: cfg,
|
||||||
|
muxlog: logmon.NewWriter(io.Discard),
|
||||||
|
proxylog: proxylog,
|
||||||
|
upstreamlog: logmon.NewWriter(io.Discard),
|
||||||
|
inflight: &inflightCounter{},
|
||||||
|
metrics: newMetricsMonitor(proxylog, 10, 0),
|
||||||
|
local: newStubRouter([]string{"m1"}, response),
|
||||||
|
peer: newStubRouter(nil, ""),
|
||||||
|
}
|
||||||
|
s.routes()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
|
||||||
|
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
|
||||||
|
s := upstreamMetricsServer(resp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != resp {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
entries := s.metrics.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 metrics entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Model != "m1" {
|
||||||
|
t.Errorf("model = %q, want m1", entries[0].Model)
|
||||||
|
}
|
||||||
|
if entries[0].ReqPath != "/v1/chat/completions" {
|
||||||
|
t.Errorf("req_path = %q, want /v1/chat/completions", entries[0].ReqPath)
|
||||||
|
}
|
||||||
|
if entries[0].Tokens.InputTokens != 3 || entries[0].Tokens.OutputTokens != 5 {
|
||||||
|
t.Errorf("tokens = %+v, want input=3 output=5", entries[0].Tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsSkipsUnsupportedPath(t *testing.T) {
|
||||||
|
s := upstreamMetricsServer("ok")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/probe", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
s.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "ok" {
|
||||||
|
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if len(s.metrics.getMetrics()) != 0 {
|
||||||
|
t.Errorf("want no metrics entries for unsupported path, got %d", len(s.metrics.getMetrics()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_MetricsSkipsGET(t *testing.T) {
|
||||||
|
s := upstreamMetricsServer(`{"usage":{}}`)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d", w.Code)
|
||||||
|
}
|
||||||
|
if len(s.metrics.getMetrics()) != 0 {
|
||||||
|
t.Errorf("want no metrics entries for GET upstream, got %d", len(s.metrics.getMetrics()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
||||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
|
|||||||
case "upstream":
|
case "upstream":
|
||||||
return s.upstreamlog, nil
|
return s.upstreamlog, nil
|
||||||
default:
|
default:
|
||||||
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found {
|
if _, modelID, _, found := shared.FindModelInPath(s.cfg, "/"+logMonitorID); found {
|
||||||
if log, ok := s.local.ProcessLogger(modelID); ok {
|
if log, ok := s.local.ProcessLogger(modelID); ok {
|
||||||
return log, nil
|
return log, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
@@ -21,8 +22,27 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine the model-routed endpoint path. Regular routes are
|
||||||
|
// already meterable; /upstream/<model>/<path> is metered only when
|
||||||
|
// the remaining path matches a model-dispatched endpoint.
|
||||||
|
checkPath := r.URL.Path
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||||
|
var found bool
|
||||||
|
_, _, checkPath, found = shared.FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||||
|
if !found {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isMetricsRecordPath(checkPath) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Resolve the model now so downstream dispatch hits the context
|
// Resolve the model now so downstream dispatch hits the context
|
||||||
// fast path; FetchContext restores the request body.
|
// fast path; FetchContext restores the request body for regular
|
||||||
|
// routes and extracts the model from the URL for /upstream routes.
|
||||||
data, err := shared.FetchContext(r, cfg)
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
shared.SendError(w, r, shared.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
@@ -31,7 +51,7 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
|
|||||||
|
|
||||||
// Buffer the request body/headers for capture before dispatch
|
// Buffer the request body/headers for capture before dispatch
|
||||||
// consumes them.
|
// consumes them.
|
||||||
cf := captureFieldsFor(r.URL.Path)
|
cf := captureFieldsFor(checkPath)
|
||||||
var reqBody []byte
|
var reqBody []byte
|
||||||
var reqHeaders map[string]string
|
var reqHeaders map[string]string
|
||||||
if mm.enableCaptures {
|
if mm.enableCaptures {
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
@@ -103,3 +106,40 @@ func TestServer_ParseMetrics_Infill(t *testing.T) {
|
|||||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody verifies that
|
||||||
|
// an /upstream/<model>/v1/audio/speech request uses the path-specific capture
|
||||||
|
// mask (headers only) rather than falling back to captureAll.
|
||||||
|
func TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
|
||||||
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "audio/mpeg")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("BINARY-AUDIO-DATA"))
|
||||||
|
})
|
||||||
|
handler := CreateMetricsMiddleware(mm, cfg)(inner)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/audio/speech", strings.NewReader(`{"model":"m1"}`))
|
||||||
|
handler.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) == 0 {
|
||||||
|
t.Fatal("no metrics recorded")
|
||||||
|
}
|
||||||
|
last := entries[len(entries)-1]
|
||||||
|
if !last.HasCapture {
|
||||||
|
t.Fatal("expected capture to be stored")
|
||||||
|
}
|
||||||
|
cap := mm.getCaptureByID(last.ID)
|
||||||
|
if cap == nil {
|
||||||
|
t.Fatal("capture not found")
|
||||||
|
}
|
||||||
|
if len(cap.RespBody) != 0 {
|
||||||
|
t.Errorf("RespBody stored for /upstream audio route (len=%d); want path-specific mask to skip body", len(cap.RespBody))
|
||||||
|
}
|
||||||
|
if len(cap.RespHeaders) == 0 {
|
||||||
|
t.Error("RespHeaders not stored; want captureRespHeaders mask")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -89,6 +89,27 @@ var modelGetRoutes = []string{
|
|||||||
"/sdapi/v1/loras",
|
"/sdapi/v1/loras",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isMetricsRecordPath reports whether path is one of the model-dispatched
|
||||||
|
// endpoints that the metrics middleware records in the activity log.
|
||||||
|
func isMetricsRecordPath(path string) bool {
|
||||||
|
for _, p := range modelPostJSONRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, p := range modelPostFormRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, p := range modelGetRoutes {
|
||||||
|
if p == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// BuildInfo carries version metadata surfaced by GET /api/version.
|
// BuildInfo carries version metadata surfaced by GET /api/version.
|
||||||
type BuildInfo struct {
|
type BuildInfo struct {
|
||||||
Version string
|
Version string
|
||||||
@@ -219,9 +240,11 @@ func (s *Server) routes() {
|
|||||||
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||||
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
||||||
|
|
||||||
// Upstream passthrough.
|
// Upstream passthrough. Meter only the model-dispatched endpoints that can
|
||||||
|
// produce token usage/timings.
|
||||||
|
upstreamChain := apiChain.Append(CreateMetricsMiddleware(s.metrics, s.cfg))
|
||||||
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
||||||
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream))
|
mux.Handle("/upstream/{upstreamPath...}", upstreamChain.ThenFunc(s.handleUpstream))
|
||||||
|
|
||||||
// API group (API-key protected) consumed by the UI.
|
// API group (API-key protected) consumed by the UI.
|
||||||
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
||||||
|
|||||||
+65
-4
@@ -91,16 +91,24 @@ func SendResponse(w http.ResponseWriter, r *http.Request, status int, message st
|
|||||||
w.Write(resp)
|
w.Write(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchContext will attempt to get the model id from the context then
|
// FetchContext will attempt to get the model id from the context, then
|
||||||
// from the model body. If it extracts the model from the body it will
|
// from an /upstream/<model> path prefix, then from the request body/query.
|
||||||
// store the model in the context for downstream handlers. An error
|
// If it extracts the model it will store it in the context for downstream
|
||||||
// will be returned when model can not be fetch from either location.
|
// handlers. An error will be returned when a model cannot be identified.
|
||||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||||
data, ok := ReadContext(r.Context())
|
data, ok := ReadContext(r.Context())
|
||||||
if ok {
|
if ok {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/upstream/") {
|
||||||
|
if data, ok := extractUpstreamContext(r, cfg); ok {
|
||||||
|
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
return ReqContextData{}, ErrNoModelInContext
|
||||||
|
}
|
||||||
|
|
||||||
if data, err := extractContext(r); err == nil && data.Model != "" {
|
if data, err := extractContext(r); err == nil && data.Model != "" {
|
||||||
realName, _ := cfg.RealModelName(data.Model)
|
realName, _ := cfg.RealModelName(data.Model)
|
||||||
if realName == "" {
|
if realName == "" {
|
||||||
@@ -117,6 +125,59 @@ func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
|||||||
return ReqContextData{}, ErrNoModelInContext
|
return ReqContextData{}, ErrNoModelInContext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractUpstreamContext resolves the model from an /upstream/<model>/... path.
|
||||||
|
func extractUpstreamContext(r *http.Request, cfg config.Config) (ReqContextData, bool) {
|
||||||
|
searchName, realName, _, found := FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
|
||||||
|
if !found {
|
||||||
|
return ReqContextData{}, false
|
||||||
|
}
|
||||||
|
return ReqContextData{
|
||||||
|
Model: searchName,
|
||||||
|
ModelID: realName,
|
||||||
|
ApiKey: ExtractAPIKey(r),
|
||||||
|
Streaming: r.URL.Query().Get("stream") == "true",
|
||||||
|
SendLoadingState: sendLoadingState(cfg, realName),
|
||||||
|
Metadata: make(map[string]string),
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendLoadingState reports whether the configured model wants loading-state SSEs.
|
||||||
|
func sendLoadingState(cfg config.Config, modelID string) bool {
|
||||||
|
if mc, ok := cfg.Models[modelID]; ok {
|
||||||
|
return mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindModelInPath walks a slash-separated path, building up segments until one
|
||||||
|
// matches a configured model. This resolves model names that contain slashes
|
||||||
|
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||||
|
// remaining path, and whether a match was found.
|
||||||
|
func FindModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||||
|
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||||
|
name := ""
|
||||||
|
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = part
|
||||||
|
} else {
|
||||||
|
name = name + "/" + part
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelID, ok := cfg.RealModelName(name); ok {
|
||||||
|
searchName = name
|
||||||
|
realName = modelID
|
||||||
|
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||||
return context.WithValue(ctx, ReqContextKey, data)
|
return context.WithValue(ctx, ReqContextKey, data)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractContext_GET(t *testing.T) {
|
func TestExtractContext_GET(t *testing.T) {
|
||||||
@@ -456,3 +458,68 @@ func TestServer_ExtractAPIKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFetchContext_UpstreamPath(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"m1": {},
|
||||||
|
"author/model": {},
|
||||||
|
"real": {Aliases: []string{"nick"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
wantModel string
|
||||||
|
wantModelID string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"known model", "/upstream/m1/v1/chat/completions", "m1", "m1", false},
|
||||||
|
{"model with slash", "/upstream/author/model/v1/chat", "author/model", "author/model", false},
|
||||||
|
{"unknown model", "/upstream/nope/v1/chat/completions", "", "", true},
|
||||||
|
{"bare model path", "/upstream/m1/", "m1", "m1", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, c.path, strings.NewReader(`{}`))
|
||||||
|
data, err := FetchContext(r, cfg)
|
||||||
|
if (err != nil) != c.wantErr {
|
||||||
|
t.Fatalf("wantErr=%v got err=%v", c.wantErr, err)
|
||||||
|
}
|
||||||
|
if c.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data.Model != c.wantModel {
|
||||||
|
t.Errorf("model = %q, want %q", data.Model, c.wantModel)
|
||||||
|
}
|
||||||
|
if data.ModelID != c.wantModelID {
|
||||||
|
t.Errorf("modelID = %q, want %q", data.ModelID, c.wantModelID)
|
||||||
|
}
|
||||||
|
if data.Metadata == nil {
|
||||||
|
t.Error("metadata map not initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchContext_UpstreamPath_DoesNotReadBody(t *testing.T) {
|
||||||
|
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||||
|
body := `{"model":"should-not-matter"}`
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(body))
|
||||||
|
|
||||||
|
_, err := FetchContext(r, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FetchContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The body should be untouched so the upstream handler can still read it.
|
||||||
|
got, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read body: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != body {
|
||||||
|
t.Errorf("body was consumed: %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user