diff --git a/internal/router/base.go b/internal/router/base.go index ab6a3c4d..a14d2e09 100644 --- a/internal/router/base.go +++ b/internal/router/base.go @@ -12,6 +12,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/router/scheduler" + "github.com/mostlygeek/llama-swap/internal/shared" ) type shutdownReq struct { @@ -399,13 +400,13 @@ func (b *baseRouter) Shutdown(timeout time.Duration) error { func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { if b.shuttingDown.Load() { - SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) return } - data, err := FetchContext(req, b.config) + data, err := shared.FetchContext(req, b.config) if err != nil { - SendError(w, req, err) + shared.SendError(w, req, err) return } @@ -424,7 +425,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { case <-req.Context().Done(): return case <-b.shutdownCtx.Done(): - SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) return } @@ -475,12 +476,12 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { return case <-b.shutdownCtx.Done(): finishLoading() - SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) return } if resp.Err != nil { - SendError(w, req, resp.Err) + shared.SendError(w, req, resp.Err) return } resp.HandleFunc(w, req) diff --git a/internal/router/loading_test.go b/internal/router/loading_test.go index 57dc3bf6..6863708e 100644 --- a/internal/router/loading_test.go +++ b/internal/router/loading_test.go @@ -226,69 +226,6 @@ func TestIsLoadingPath(t *testing.T) { } } -func TestExtractContext_Streaming_GET(t *testing.T) { - tests := []struct { - name string - query string 
- wantStreaming bool - }{ - {"streaming true", "model=llama3&stream=true", true}, - {"streaming false", "model=llama3&stream=false", false}, - {"no stream param", "model=llama3", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil) - got, err := ExtractContext(r) - if err != nil { - t.Fatalf("ExtractContext: %v", err) - } - if got.Streaming != tt.wantStreaming { - t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming) - } - }) - } -} - -func TestExtractContext_Streaming_JSON(t *testing.T) { - tests := []struct { - name string - body string - wantStreaming bool - }{ - {"streaming true", `{"model":"llama3","stream":true}`, true}, - {"streaming false", `{"model":"llama3","stream":false}`, false}, - {"no stream param", `{"model":"llama3"}`, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body)) - r.Header.Set("Content-Type", "application/json") - got, err := ExtractContext(r) - if err != nil { - t.Fatalf("ExtractContext: %v", err) - } - if got.Streaming != tt.wantStreaming { - t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming) - } - }) - } -} - -func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) { - r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true")) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - got, err := ExtractContext(r) - if err != nil { - t.Fatalf("ExtractContext: %v", err) - } - if !got.Streaming { - t.Error("Streaming should be true") - } -} - func countSSEMessages(s string) int { scanner := bufio.NewScanner(strings.NewReader(s)) count := 0 diff --git a/internal/router/peer.go b/internal/router/peer.go index fb17068c..e46447a6 100644 --- a/internal/router/peer.go +++ b/internal/router/peer.go @@ -15,6 +15,7 @@ 
import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/shared" ) type peerMember struct { @@ -146,22 +147,22 @@ func (r *Peer) Shutdown(timeout time.Duration) error { func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if r.shuttingDown.Load() { - SendError(w, req, fmt.Errorf("peer proxy is shutting down")) + shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down")) return } r.inflight.Add(1) defer r.inflight.Done() - data, err := FetchContext(req, r.cfg) + data, err := shared.FetchContext(req, r.cfg) if err != nil { - SendError(w, req, err) + shared.SendError(w, req, err) return } pp, found := r.peers[data.ModelID] if !found { r.logger.Warnf("peer model not found: %s", data.ModelID) - SendError(w, req, ErrNoPeerModelFound) + shared.SendError(w, req, ErrNoPeerModelFound) return } diff --git a/internal/router/peer_test.go b/internal/router/peer_test.go index 74527bcb..1641c420 100644 --- a/internal/router/peer_test.go +++ b/internal/router/peer_test.go @@ -12,6 +12,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/shared" ) var testLogger = logmon.NewWriter(os.Stdout) @@ -142,7 +143,7 @@ func TestPeer_ServeHTTP_Success(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -178,7 +179,7 @@ func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: 
"nonexistent-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -212,7 +213,7 @@ func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -246,7 +247,7 @@ func TestPeer_ServeHTTP_NoApiKey(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -279,7 +280,7 @@ func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -311,7 +312,7 @@ func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -347,7 +348,7 @@ func 
TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -385,7 +386,7 @@ func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) var wg sync.WaitGroup wg.Add(1) @@ -448,7 +449,7 @@ func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) var wg sync.WaitGroup wg.Add(1) @@ -551,7 +552,7 @@ func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) { body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`) req := httptest.NewRequest("POST", "/v1/chat/completions", body) req.Header.Set("Content-Type", "application/json") - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "context-model", ModelID: "context-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) diff --git a/internal/router/router.go b/internal/router/router.go index 8ebc1f11..561c5fc8 100644 --- a/internal/router/router.go +++ 
b/internal/router/router.go @@ -1,40 +1,18 @@ package router import ( - "bytes" - "context" - "errors" - "fmt" - "io" "net/http" - "strings" "time" - "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" - "github.com/mostlygeek/llama-swap/internal/router/scheduler" - "github.com/tidwall/gjson" + "github.com/mostlygeek/llama-swap/internal/shared" ) -type contextkey struct { - name string -} - -type ReqContextData struct { - Model string - ModelID string - Streaming bool - SendLoadingState bool -} - var ( - ErrNoModelInContext = fmt.Errorf("no model in request context") - ErrNoRouterFound = fmt.Errorf("no router found for model") - ErrNoPeerModelFound = fmt.Errorf("peer model not found") - ErrNoLocalModelFound = scheduler.ErrModelNotFound - - ContextKey = &contextkey{"context"} + ErrNoRouterFound = shared.ErrNoRouterFound + ErrNoPeerModelFound = shared.ErrNoPeerModelFound + ErrNoLocalModelFound = shared.ErrNoLocalModelFound ) type Router interface { @@ -72,129 +50,3 @@ type LocalRouter interface { // model is not known to this router. ProcessLogger(modelID string) (*logmon.Monitor, bool) } - -// FetchContext will attempt to get the model id from the context then -// from the model body. If it extracts the model from the body it will -// store the model in the context for downstream handlers. An error -// will be returned when model can not be fetch from either location. 
-func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { - data, ok := ReadContext(r.Context()) - if ok { - return data, nil - } - - if data, err := ExtractContext(r); err == nil { - realName, _ := cfg.RealModelName(data.Model) - if realName == "" { - realName = data.Model - } - data.ModelID = realName - if mc, ok := cfg.Models[realName]; ok { - data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState - } - *r = *r.WithContext(SetContext(r.Context(), data)) - return data, nil - } - - return ReqContextData{}, ErrNoModelInContext -} - -func SetContext(ctx context.Context, data ReqContextData) context.Context { - return context.WithValue(ctx, ContextKey, data) -} - -func ReadContext(ctx context.Context) (ReqContextData, bool) { - data, ok := ctx.Value(ContextKey).(ReqContextData) - return data, ok -} - -// ExtractContext pulls the model name from an HTTP request without consuming the -// body. For GET requests it reads the "model" query parameter. For POST -// requests it inspects Content-Type and parses JSON, multipart/form-data, or -// application/x-www-form-urlencoded bodies. The request body is always restored -// before returning so downstream handlers — including reverse proxies that -// forward raw bytes upstream — can still read it. 
-func ExtractContext(r *http.Request) (ReqContextData, error) { - if r.Method == http.MethodGet { - if model := r.URL.Query().Get("model"); model != "" { - return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil - } - return ReqContextData{}, fmt.Errorf("missing 'model' query parameter") - } - - bodyBytes, err := io.ReadAll(r.Body) - if err != nil { - return ReqContextData{}, fmt.Errorf("error reading request body: %w", err) - } - defer func() { - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - }() - - contentType := r.Header.Get("Content-Type") - - if strings.Contains(contentType, "application/json") { - model := gjson.GetBytes(bodyBytes, "model").String() - if model == "" { - return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body") - } - return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil - } - - // Form parsers read from r.Body, so feed them a fresh reader over the - // buffered bytes. The deferred restore above will reset r.Body again - // after parsing. 
- r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - if strings.Contains(contentType, "multipart/form-data") { - if err := r.ParseMultipartForm(32 << 20); err != nil { - return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err) - } - } else { - if err := r.ParseForm(); err != nil { - return ReqContextData{}, fmt.Errorf("error parsing form: %w", err) - } - } - - if model := r.FormValue("model"); model != "" { - return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil - } - - return ReqContextData{}, fmt.Errorf("missing 'model' parameter") -} - -func SendError(w http.ResponseWriter, r *http.Request, err error) { - switch { - case errors.Is(err, ErrNoModelInContext): - SendResponse(w, r, http.StatusNotFound, "no model id could be identified") - case errors.Is(err, ErrNoPeerModelFound): - SendResponse(w, r, http.StatusNotFound, "no peer found for requested model") - case errors.Is(err, ErrNoLocalModelFound): - SendResponse(w, r, http.StatusNotFound, "no local server found for requested model") - case errors.Is(err, ErrNoRouterFound): - SendResponse(w, r, http.StatusNotFound, "no router for requested model") - default: - SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err)) - } -} - -// SendResponse detects what content type the client prefers and returns an error response in that format. -func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) { - // Check Accept header for preferred response format - acceptHeader := r.Header.Get("Accept") - if strings.Contains(acceptHeader, "text/plain") { - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(status) - w.Write([]byte(fmt.Sprintf("llama-swap: %s", message))) - return - } - - if strings.Contains(acceptHeader, "text/html") { - w.Header().Set("Content-Type", "text/html") - w.WriteHeader(status) - w.Write([]byte(fmt.Sprintf(`

llama-swap

%s

`, message))) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message))) -} diff --git a/internal/router/scheduler/scheduler.go b/internal/router/scheduler/scheduler.go index 9dc9e281..87ed6ad2 100644 --- a/internal/router/scheduler/scheduler.go +++ b/internal/router/scheduler/scheduler.go @@ -11,17 +11,17 @@ package scheduler import ( "context" - "fmt" "net/http" "time" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" + "github.com/mostlygeek/llama-swap/internal/shared" ) // ErrModelNotFound is granted to callers whose model is not handled by this -// router. The router package aliases it so SendError can match it. -var ErrModelNotFound = fmt.Errorf("local model not found") +// router. It is an alias for shared.ErrNoLocalModelFound. +var ErrModelNotFound = shared.ErrNoLocalModelFound // Swapper is the eviction policy: it decides which running models must be // stopped before a target can serve. 
It is orthogonal to the scheduling diff --git a/internal/server/api.go b/internal/server/api.go index d3723c63..fe11ac2b 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -9,7 +9,6 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/event" - "github.com/mostlygeek/llama-swap/internal/router" "github.com/mostlygeek/llama-swap/internal/shared" ) @@ -163,7 +162,7 @@ func (s *Server) startPreload() { if err != nil { continue } - req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID})) + req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID})) dw := &discardResponseWriter{status: http.StatusOK} s.local.ServeHTTP(dw, req) @@ -208,7 +207,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath) if !found { - router.SendResponse(w, r, http.StatusNotFound, "model not found") + shared.SendResponse(w, r, http.StatusNotFound, "model not found") return } @@ -230,7 +229,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { // Strip the /upstream/ prefix before forwarding. r.URL.Path = remainingPath // Pin the resolved model so the router skips body/query extraction. 
- *r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID})) + *r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID})) switch { case s.local.Handles(modelID): @@ -238,7 +237,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { case s.peer.Handles(modelID): s.peer.ServeHTTP(w, r) default: - router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID) + shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID) } } diff --git a/internal/server/apigroup.go b/internal/server/apigroup.go index 1e3131bb..c989ae2b 100644 --- a/internal/server/apigroup.go +++ b/internal/server/apigroup.go @@ -12,7 +12,6 @@ import ( "github.com/mostlygeek/llama-swap/internal/event" "github.com/mostlygeek/llama-swap/internal/perf" - "github.com/mostlygeek/llama-swap/internal/router" "github.com/mostlygeek/llama-swap/internal/shared" ) @@ -76,11 +75,11 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) { requested := strings.TrimPrefix(r.PathValue("model"), "/") realName, found := s.cfg.RealModelName(requested) if !found { - router.SendResponse(w, r, http.StatusNotFound, "model not found") + shared.SendResponse(w, r, http.StatusNotFound, "model not found") return } if !s.local.Handles(realName) { - router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model") + shared.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model") return } s.local.Unload(apiUnloadTimeout, realName) @@ -92,7 +91,7 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) { data, err := s.metrics.getMetricsJSON() if err != nil { - router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics") + shared.SendResponse(w, r, 
http.StatusInternalServerError, "failed to get metrics") return } w.Header().Set("Content-Type", "application/json") @@ -103,7 +102,7 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) { // filtered to samples after the ?after= timestamp. func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) { if s.perf == nil { - router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available") + shared.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available") return } @@ -112,7 +111,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) { if afterStr := r.URL.Query().Get("after"); afterStr != "" { after, err := time.Parse(time.RFC3339, afterStr) if err != nil { - router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format") + shared.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format") return } filteredSys := make([]perf.SysStat, 0, len(sysStats)) @@ -153,19 +152,19 @@ func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) { id, err := strconv.Atoi(r.PathValue("id")) if err != nil { - router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID") + shared.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID") return } capture := s.metrics.getCaptureByID(id) if capture == nil { - router.SendResponse(w, r, http.StatusNotFound, "capture not found") + shared.SendResponse(w, r, http.StatusNotFound, "capture not found") return } jsonBytes, err := json.Marshal(capture) if err != nil { - router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture") + shared.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture") return } w.Header().Set("Content-Type", "application/json") @@ -198,7 +197,7 @@ func (s *Server) handleAPIEvents(w 
http.ResponseWriter, r *http.Request) { flusher, ok := w.(http.Flusher) if !ok { - router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported") + shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported") return } diff --git a/internal/server/auth.go b/internal/server/auth.go index e385b733..b78b4797 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -1,19 +1,17 @@ package server import ( - "encoding/base64" "net/http" "strings" "github.com/mostlygeek/llama-swap/internal/chain" "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" ) // CreateAuthMiddleware returns middleware that validates API keys when the // config declares any. It accepts the key via Authorization: Bearer, -// Authorization: Basic (password field), or x-api-key. On success the auth -// headers are stripped so they never leak to upstream. When no keys are +// Authorization: Basic (password field), or x-api-key. When no keys are // configured the middleware is a pass-through. 
func CreateAuthMiddleware(cfg config.Config) chain.Middleware { keys := cfg.RequiredAPIKeys @@ -22,7 +20,7 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - provided := extractAPIKey(r) + provided := shared.ExtractAPIKey(r) valid := false for _, key := range keys { @@ -33,41 +31,29 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware { } if !valid { w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`) - router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key") + shared.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key") return } - r.Header.Del("Authorization") - r.Header.Del("x-api-key") next.ServeHTTP(w, r) }) } } -// extractAPIKey pulls a candidate API key from the request, preferring Basic, -// then Bearer, then x-api-key. -func extractAPIKey(r *http.Request) string { - var bearerKey, basicKey string - if auth := r.Header.Get("Authorization"); auth != "" { - if strings.HasPrefix(auth, "Bearer ") { - bearerKey = strings.TrimPrefix(auth, "Bearer ") - } else if strings.HasPrefix(auth, "Basic ") { - encoded := strings.TrimPrefix(auth, "Basic ") - if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil { - if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 { - basicKey = parts[1] // password field is the API key - } +// CreateRequestContextMiddleware returns middleware that extracts model and +// auth info from the request into the context. Requests where no model can be +// identified are rejected with a 404. 
+func CreateRequestContextMiddleware(cfg config.Config) chain.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, err := shared.FetchContext(r, cfg) + if err != nil { + shared.SendError(w, r, shared.ErrNoModelInContext) + return } - } - } - - switch { - case basicKey != "": - return basicKey - case bearerKey != "": - return bearerKey - default: - return r.Header.Get("x-api-key") + _ = data + next.ServeHTTP(w, r) + }) } } diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go index e722e4a7..31e419f3 100644 --- a/internal/server/auth_test.go +++ b/internal/server/auth_test.go @@ -1,48 +1,14 @@ package server import ( - "encoding/base64" "net/http" "net/http/httptest" + "strings" "testing" "github.com/mostlygeek/llama-swap/internal/config" ) -func TestServer_ExtractAPIKey(t *testing.T) { - basicHeader := func(user, pass string) string { - return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) - } - cases := []struct { - name string - auth string - xapi string - want string - }{ - {"none", "", "", ""}, - {"bearer", "Bearer tok123", "", "tok123"}, - {"basic", basicHeader("user", "pw-key"), "", "pw-key"}, - {"x-api-key", "", "xkey", "xkey"}, - {"basic beats bearer", basicHeader("u", "bk"), "", "bk"}, - {"bearer beats x-api-key", "Bearer btok", "xkey", "btok"}, - {"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"}, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - if c.auth != "" { - r.Header.Set("Authorization", c.auth) - } - if c.xapi != "" { - r.Header.Set("x-api-key", c.xapi) - } - if got := extractAPIKey(r); got != c.want { - t.Errorf("extractAPIKey() = %q, want %q", got, c.want) - } - }) - } -} - func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) { cases := []struct { in string @@ -74,11 +40,42 @@ func 
TestServer_IsTokenChar(t *testing.T) { } } +func TestServer_RequestContextMiddleware(t *testing.T) { + cfg := config.Config{ + Models: map[string]config.ModelConfig{ + "llama3": {}, + }, + } + + final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + mw := CreateRequestContextMiddleware(cfg) + + t.Run("known model passes through", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"llama3"}`)) + r.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + mw(final).ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } + }) + + t.Run("missing model returns 404", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`)) + r.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + mw(final).ServeHTTP(w, r) + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want 404", w.Code) + } + }) +} + func TestServer_AuthMiddleware(t *testing.T) { final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" { - t.Error("auth headers leaked to upstream") - } w.WriteHeader(http.StatusOK) }) diff --git a/internal/server/concurrency.go b/internal/server/concurrency.go index ddb05c13..ccc339f3 100644 --- a/internal/server/concurrency.go +++ b/internal/server/concurrency.go @@ -7,7 +7,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/chain" "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" ) // defaultConcurrencyLimit caps simultaneous in-flight requests per model when @@ -32,9 +32,9 @@ func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware { return func(next http.Handler) http.Handler { return 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - data, err := router.FetchContext(r, cfg) + data, err := shared.FetchContext(r, cfg) if err != nil { - router.SendError(w, r, router.ErrNoModelInContext) + shared.SendError(w, r, shared.ErrNoModelInContext) return } diff --git a/internal/server/concurrency_test.go b/internal/server/concurrency_test.go index c9aa91f9..9cc68f97 100644 --- a/internal/server/concurrency_test.go +++ b/internal/server/concurrency_test.go @@ -7,12 +7,12 @@ import ( "testing" "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" ) func concurrencyTestReq(model string) *http.Request { r := httptest.NewRequest("GET", "/v1/chat/completions", nil) - return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model})) + return r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: model, ModelID: model})) } func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) { diff --git a/internal/server/filters.go b/internal/server/filters.go index 209e14c6..95e61953 100644 --- a/internal/server/filters.go +++ b/internal/server/filters.go @@ -11,7 +11,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/chain" "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" "github.com/tidwall/sjson" ) @@ -34,9 +34,9 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware { return } - data, err := router.FetchContext(r, cfg) + data, err := shared.FetchContext(r, cfg) if err != nil { - router.SendError(w, r, router.ErrNoModelInContext) + shared.SendError(w, r, shared.ErrNoModelInContext) return } @@ -48,13 +48,13 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware { body, err := io.ReadAll(r.Body) if err != nil { - router.SendResponse(w, r, 
http.StatusBadRequest, "could not read request body") + shared.SendResponse(w, r, http.StatusBadRequest, "could not read request body") return } body, err = applyFilters(body, data.Model, useModelName, filters) if err != nil { - router.SendResponse(w, r, http.StatusInternalServerError, err.Error()) + shared.SendResponse(w, r, http.StatusInternalServerError, err.Error()) return } @@ -84,9 +84,9 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware { return } - data, err := router.FetchContext(r, cfg) + data, err := shared.FetchContext(r, cfg) if err != nil { - router.SendError(w, r, router.ErrNoModelInContext) + shared.SendError(w, r, shared.ErrNoModelInContext) return } @@ -97,13 +97,13 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware { } if err := r.ParseMultipartForm(32 << 20); err != nil { - router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) + shared.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) return } body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName) if err != nil { - router.SendResponse(w, r, http.StatusInternalServerError, err.Error()) + shared.SendResponse(w, r, http.StatusInternalServerError, err.Error()) return } diff --git a/internal/server/log.go b/internal/server/log.go index a41d5952..bd776780 100644 --- a/internal/server/log.go +++ b/internal/server/log.go @@ -14,7 +14,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/chain" "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" - "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" ) // NewLoggers builds the proxy, upstream, and combined (mux) log monitors, @@ -102,13 +102,13 @@ func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) { logger, err := s.getLogger(logMonitorID) if err != nil { - 
router.SendResponse(w, r, http.StatusBadRequest, err.Error()) + shared.SendResponse(w, r, http.StatusBadRequest, err.Error()) return } flusher, ok := w.(http.Flusher) if !ok { - router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported") + shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported") return } diff --git a/internal/server/metrics_middleware.go b/internal/server/metrics_middleware.go index b52a705a..8ba41825 100644 --- a/internal/server/metrics_middleware.go +++ b/internal/server/metrics_middleware.go @@ -7,7 +7,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/chain" "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" ) // CreateMetricsMiddleware returns middleware that records token metrics for @@ -23,9 +23,9 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle // Resolve the model now so downstream dispatch hits the context // fast path; FetchContext restores the request body. - data, err := router.FetchContext(r, cfg) + data, err := shared.FetchContext(r, cfg) if err != nil { - router.SendError(w, r, router.ErrNoModelInContext) + shared.SendError(w, r, shared.ErrNoModelInContext) return } diff --git a/internal/server/server.go b/internal/server/server.go index 5f36de32..f2ad15ab 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,6 +15,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/perf" "github.com/mostlygeek/llama-swap/internal/router" + "github.com/mostlygeek/llama-swap/internal/shared" ) // Server owns the HTTP mux, cross-cutting middleware, and the local/peer model @@ -138,13 +139,13 @@ func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, up } // localPeerHandler dispatches a model-routed request to the local or peer -// router. 
The model is resolved once via router.FetchContext. +// router. The model is resolved once via shared.FetchContext. func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) { stripVersionPrefix(r) - data, err := router.FetchContext(r, s.cfg) + data, err := shared.FetchContext(r, s.cfg) if err != nil { - router.SendError(w, r, router.ErrNoModelInContext) + shared.SendError(w, r, shared.ErrNoModelInContext) return } @@ -156,7 +157,7 @@ func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) { s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID) s.peer.ServeHTTP(w, r) default: - router.SendError(w, r, router.ErrNoRouterFound) + shared.SendError(w, r, router.ErrNoRouterFound) } } @@ -171,21 +172,14 @@ func stripVersionPrefix(r *http.Request) { // routes builds the mux, registers every route, and wraps the mux with the // global CORS middleware. func (s *Server) routes() { - authMW := CreateAuthMiddleware(s.cfg) - filterMW := CreateFilterMiddleware(s.cfg) - formFilterMW := CreateFormFilterMiddleware(s.cfg) - // Model-dispatched routes get auth + per-model concurrency limiting + body - // filters + in-flight tracking + token metrics. concurrencyMW rejects with - // 429 before the body filters do any rewrite work. filterMW rewrites JSON - // bodies and formFilterMW rewrites multipart bodies; each is a no-op for the - // other's Content-Type. Both run before the metrics middleware so it buffers - // the rewritten body. + authMW := CreateAuthMiddleware(s.cfg) modelChain := chain.New( authMW, + CreateRequestContextMiddleware(s.cfg), CreateConcurrencyMiddleware(s.cfg), - filterMW, - formFilterMW, + CreateFilterMiddleware(s.cfg), + CreateFormFilterMiddleware(s.cfg), CreateInflightMiddleware(s.inflight), CreateMetricsMiddleware(s.metrics, s.cfg), ) @@ -216,11 +210,11 @@ func (s *Server) routes() { mux.HandleFunc("GET /{$}", handleRootRedirect) // Embedded UI. 
- mux.HandleFunc("GET /ui/", s.handleUI) + mux.Handle("GET /ui/", chain.New(authMW).ThenFunc(s.handleUI)) mux.HandleFunc("GET /favicon.ico", s.handleFavicon) - // Prometheus metrics (no auth, matches the legacy endpoint). - mux.HandleFunc("GET /metrics", s.handleMetrics) + // Prometheus metrics (wrapped by apiChain, matches the legacy endpoint). + mux.Handle("GET /metrics", apiChain.ThenFunc(s.handleMetrics)) // Operations endpoints. mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload)) diff --git a/internal/shared/http.go b/internal/shared/http.go new file mode 100644 index 00000000..adc19e87 --- /dev/null +++ b/internal/shared/http.go @@ -0,0 +1,202 @@ +package shared + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "html" + "io" + "net/http" + "strings" + + "github.com/mostlygeek/llama-swap/internal/config" + "github.com/tidwall/gjson" +) + +type contextkey struct { + name string +} + +type ReqContextData struct { + ApiKey string + Model string + ModelID string + Streaming bool + SendLoadingState bool +} + +var ( + ReqContextKey = &contextkey{"context"} + ErrNoModelInContext = fmt.Errorf("no model in request context") + ErrNoRouterFound = fmt.Errorf("no router found for model") + ErrNoPeerModelFound = fmt.Errorf("peer model not found") + ErrNoLocalModelFound = fmt.Errorf("local model not found") +) + +func SendError(w http.ResponseWriter, r *http.Request, err error) { + switch { + case errors.Is(err, ErrNoModelInContext): + SendResponse(w, r, http.StatusNotFound, "no model id could be identified") + case errors.Is(err, ErrNoPeerModelFound): + SendResponse(w, r, http.StatusNotFound, "no peer found for requested model") + case errors.Is(err, ErrNoLocalModelFound): + SendResponse(w, r, http.StatusNotFound, "no local server found for requested model") + case errors.Is(err, ErrNoRouterFound): + SendResponse(w, r, http.StatusNotFound, "no router for requested model") + default: + SendResponse(w, r, 
http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err)) + } +} + +// SendResponse detects what content type the client prefers and returns an error response in that format. +func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) { + acceptHeader := r.Header.Get("Accept") + if strings.Contains(acceptHeader, "text/plain") { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(status) + w.Write([]byte(fmt.Sprintf("llama-swap: %s", message))) + return + } + + if strings.Contains(acceptHeader, "text/html") { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(status) + w.Write([]byte(fmt.Sprintf(`

llama-swap

%s

`, html.EscapeString(message)))) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + resp, err := json.Marshal(map[string]string{"src": "llama-swap", "error": message}) + if err != nil { + w.Write([]byte(`{"src":"llama-swap", "error": "failed to marshal response"}`)) + return + } + w.Write(resp) +} + +// FetchContext will attempt to get the model id from the context, then +// from the request (query parameters or body). If it extracts the model from the request it will +// store the model in the context for downstream handlers. An error +// will be returned when the model cannot be fetched from either location. +func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { + data, ok := ReadContext(r.Context()) + if ok { + return data, nil + } + + if data, err := extractContext(r); err == nil && data.Model != "" { + realName, _ := cfg.RealModelName(data.Model) + if realName == "" { + realName = data.Model + } + data.ModelID = realName + if mc, ok := cfg.Models[realName]; ok { + data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState + } + *r = *r.WithContext(SetContext(r.Context(), data)) + return data, nil + } + + return ReqContextData{}, ErrNoModelInContext +} + +func SetContext(ctx context.Context, data ReqContextData) context.Context { + return context.WithValue(ctx, ReqContextKey, data) +} + +func ReadContext(ctx context.Context) (ReqContextData, bool) { + data, ok := ctx.Value(ReqContextKey).(ReqContextData) + return data, ok +} + +// extractContext pulls fields from an HTTP request into a ReqContextData, +// returning whatever is available. For GET requests it reads query parameters. +// For POST requests it inspects Content-Type and parses JSON, +// multipart/form-data, or application/x-www-form-urlencoded bodies. The +// request body is always restored before returning. An error is returned only +// for I/O or parse failures, not for missing fields. 
+func extractContext(r *http.Request) (ReqContextData, error) { + + apiKey := ExtractAPIKey(r) + + if r.Method == http.MethodGet { + q := r.URL.Query() + return ReqContextData{ + Model: q.Get("model"), + Streaming: q.Get("stream") == "true", + ApiKey: apiKey, + }, nil + } + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return ReqContextData{}, fmt.Errorf("error reading request body: %w", err) + } + defer func() { + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + }() + + contentType := r.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/json") { + return ReqContextData{ + Model: gjson.GetBytes(bodyBytes, "model").String(), + Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(), + ApiKey: apiKey, + }, nil + } + + // Form parsers read from r.Body, so feed them a fresh reader over the + // buffered bytes. The deferred restore above will reset r.Body again + // after parsing. + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + if strings.Contains(contentType, "multipart/form-data") { + if err := r.ParseMultipartForm(32 << 20); err != nil { + return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err) + } + } else { + if err := r.ParseForm(); err != nil { + return ReqContextData{}, fmt.Errorf("error parsing form: %w", err) + } + } + + return ReqContextData{ + Model: r.FormValue("model"), + Streaming: r.FormValue("stream") == "true", + ApiKey: apiKey, + }, nil +} + +// ExtractAPIKey pulls a candidate API key from the request, preferring Basic, +// then Bearer, then x-api-key. 
+func ExtractAPIKey(r *http.Request) string { + var bearerKey, basicKey string + if auth := r.Header.Get("Authorization"); auth != "" { + scheme, credentials, ok := strings.Cut(auth, " ") + if ok { + switch strings.ToLower(scheme) { + case "bearer": + bearerKey = credentials + case "basic": + if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil { + if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 { + basicKey = parts[1] // password field is the API key + } + } + } + } + } + + switch { + case basicKey != "": + return basicKey + case bearerKey != "": + return bearerKey + default: + return r.Header.Get("x-api-key") + } +} diff --git a/internal/router/router_test.go b/internal/shared/http_test.go similarity index 56% rename from internal/router/router_test.go rename to internal/shared/http_test.go index fa88364c..8bbdd69c 100644 --- a/internal/router/router_test.go +++ b/internal/shared/http_test.go @@ -1,11 +1,13 @@ -package router +package shared import ( "bytes" "context" + "encoding/base64" "io" "mime/multipart" "net/http" + "net/http/httptest" "net/url" "strings" "testing" @@ -20,13 +22,13 @@ func TestExtractContext_GET(t *testing.T) { }{ {"model present", "model=llama3", "llama3", false}, {"model with slashes", "model=author/model-7b", "author/model-7b", false}, - {"model missing", "", "", true}, + {"model missing", "", "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil) - got, err := ExtractContext(r) + got, err := extractContext(r) if (err != nil) != tt.wantErr { t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) } @@ -46,16 +48,16 @@ func TestExtractContext_JSON(t *testing.T) { }{ {"model present", `{"model":"llama3","stream":true}`, "llama3", false}, {"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false}, - {"model empty string", `{"model":""}`, "", true}, - {"model key missing", `{"stream":true}`, "", true}, 
- {"invalid json", `not-json`, "", true}, + {"model empty string", `{"model":""}`, "", false}, + {"model key missing", `{"stream":true}`, "", false}, + {"invalid json", `not-json`, "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body)) r.Header.Set("Content-Type", "application/json") - got, err := ExtractContext(r) + got, err := extractContext(r) if (err != nil) != tt.wantErr { t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) } @@ -74,7 +76,7 @@ func TestExtractContext_URLEncodedForm(t *testing.T) { wantErr bool }{ {"model present", "whisper-1", "whisper-1", false}, - {"model missing", "", "", true}, + {"model missing", "", "", false}, } for _, tt := range tests { @@ -85,7 +87,7 @@ func TestExtractContext_URLEncodedForm(t *testing.T) { } r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode())) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - got, err := ExtractContext(r) + got, err := extractContext(r) if (err != nil) != tt.wantErr { t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) } @@ -104,7 +106,7 @@ func TestExtractContext_MultipartForm(t *testing.T) { wantErr bool }{ {"model present", "whisper-1", "whisper-1", false}, - {"model missing", "", "", true}, + {"model missing", "", "", false}, } for _, tt := range tests { @@ -119,7 +121,7 @@ func TestExtractContext_MultipartForm(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf) r.Header.Set("Content-Type", mw.FormDataContentType()) - got, err := ExtractContext(r) + got, err := extractContext(r) if (err != nil) != tt.wantErr { t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) } @@ -135,7 +137,7 @@ func TestExtractContext_JSONBodyRestored(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) r.Header.Set("Content-Type", "application/json") - if 
_, err := ExtractContext(r); err != nil { + if _, err := extractContext(r); err != nil { t.Fatalf("ExtractContext: %v", err) } @@ -162,7 +164,7 @@ func TestExtractContext_MultipartBodyRestored(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original)) r.Header.Set("Content-Type", mw.FormDataContentType()) - if _, err := ExtractContext(r); err != nil { + if _, err := extractContext(r); err != nil { t.Fatalf("ExtractContext: %v", err) } @@ -180,7 +182,7 @@ func TestExtractContext_URLEncodedBodyRestored(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body)) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if _, err := ExtractContext(r); err != nil { + if _, err := extractContext(r); err != nil { t.Fatalf("ExtractContext: %v", err) } @@ -195,7 +197,7 @@ func TestExtractContext_URLEncodedBodyRestored(t *testing.T) { func TestSetContext(t *testing.T) { ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}) - data, ok := ctx.Value(ContextKey).(ReqContextData) + data, ok := ctx.Value(ReqContextKey).(ReqContextData) if !ok { t.Fatalf("ContextKey not set or wrong type") } @@ -209,7 +211,7 @@ func TestSetContext(t *testing.T) { func TestSetContext_WithAlias(t *testing.T) { ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}) - data, _ := ctx.Value(ContextKey).(ReqContextData) + data, _ := ctx.Value(ReqContextKey).(ReqContextData) if data.Model != "llama" { t.Errorf("want requested %q got %q", "llama", data.Model) } @@ -221,7 +223,7 @@ func TestSetContext_WithAlias(t *testing.T) { func TestSetContext_DoesNotMutateParent(t *testing.T) { parent := context.Background() _ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"}) - if v := parent.Value(ContextKey); v != nil { + if v := parent.Value(ReqContextKey); v != nil { t.Errorf("parent context was mutated: %v", 
v) } } @@ -273,3 +275,152 @@ func TestReadContext(t *testing.T) { }) } } + +func TestExtractContext_Streaming_GET(t *testing.T) { + tests := []struct { + name string + query string + wantStreaming bool + }{ + {"streaming true", "model=llama3&stream=true", true}, + {"streaming false", "model=llama3&stream=false", false}, + {"no stream param", "model=llama3", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil) + got, err := extractContext(r) + if err != nil { + t.Fatalf("ExtractContext: %v", err) + } + if got.Streaming != tt.wantStreaming { + t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming) + } + }) + } +} + +func TestExtractContext_Streaming_JSON(t *testing.T) { + tests := []struct { + name string + body string + wantStreaming bool + }{ + {"streaming true", `{"model":"llama3","stream":true}`, true}, + {"streaming false", `{"model":"llama3","stream":false}`, false}, + {"no stream param", `{"model":"llama3"}`, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body)) + r.Header.Set("Content-Type", "application/json") + got, err := extractContext(r) + if err != nil { + t.Fatalf("ExtractContext: %v", err) + } + if got.Streaming != tt.wantStreaming { + t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming) + } + }) + } +} + +func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true")) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + got, err := extractContext(r) + if err != nil { + t.Fatalf("ExtractContext: %v", err) + } + if !got.Streaming { + t.Error("Streaming should be true") + } +} + +func TestExtractContext_ApiKey(t *testing.T) { + basicHeader := func(user, pass string) string 
{ + return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) + } + cases := []struct { + name string + method string + ct string + body string + auth string + xapi string + wantKey string + }{ + {"GET bearer", http.MethodGet, "", "", "Bearer sk-get", "", "sk-get"}, + {"GET x-api-key", http.MethodGet, "", "", "", "xk-get", "xk-get"}, + {"GET basic", http.MethodGet, "", "", basicHeader("u", "pw-get"), "", "pw-get"}, + {"JSON bearer", http.MethodPost, "application/json", `{"model":"m"}`, "Bearer sk-json", "", "sk-json"}, + {"JSON x-api-key", http.MethodPost, "application/json", `{"model":"m"}`, "", "xk-json", "xk-json"}, + {"form bearer", http.MethodPost, "application/x-www-form-urlencoded", "model=m", "Bearer sk-form", "", "sk-form"}, + {"no key", http.MethodGet, "", "", "", "", ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var body io.Reader + if c.body != "" { + body = strings.NewReader(c.body) + } + r, _ := http.NewRequest(c.method, "/", body) + if c.ct != "" { + r.Header.Set("Content-Type", c.ct) + } + if c.auth != "" { + r.Header.Set("Authorization", c.auth) + } + if c.xapi != "" { + r.Header.Set("x-api-key", c.xapi) + } + got, err := extractContext(r) + if err != nil { + t.Fatalf("extractContext: %v", err) + } + if got.ApiKey != c.wantKey { + t.Errorf("ApiKey = %q, want %q", got.ApiKey, c.wantKey) + } + }) + } +} + +func TestServer_ExtractAPIKey(t *testing.T) { + basicHeader := func(user, pass string) string { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) + } + cases := []struct { + name string + auth string + xapi string + want string + }{ + {"none", "", "", ""}, + {"bearer", "Bearer tok123", "", "tok123"}, + {"basic", basicHeader("user", "pw-key"), "", "pw-key"}, + {"x-api-key", "", "xkey", "xkey"}, + {"basic beats bearer", basicHeader("u", "bk"), "", "bk"}, + {"bearer beats x-api-key", "Bearer btok", "xkey", "btok"}, + {"malformed basic falls back to x-api-key", "Basic 
!!!notbase64", "xkey", "xkey"}, + {"lowercase bearer", "bearer tok123", "", "tok123"}, + {"lowercase basic", "basic " + base64.StdEncoding.EncodeToString([]byte("user:pw-key")), "", "pw-key"}, + {"mixed case BEARER", "BEARER tok456", "", "tok456"}, + {"mixed case bAsIc", "bAsIc " + base64.StdEncoding.EncodeToString([]byte("u:bk")), "", "bk"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + if c.auth != "" { + r.Header.Set("Authorization", c.auth) + } + if c.xapi != "" { + r.Header.Set("x-api-key", c.xapi) + } + if got := ExtractAPIKey(r); got != c.want { + t.Errorf("extractAPIKey() = %q, want %q", got, c.want) + } + }) + } +}