From 62aea0e83df49f5d2722031df4818e79d15f1fef Mon Sep 17 00:00:00 2001 From: Benson Wong <83972+mostlygeek@users.noreply.github.com> Date: Sat, 13 Jun 2026 10:19:04 -0700 Subject: [PATCH] internal/router,server,shared: refactor auth, libs (#839) - refactor shared http functionality into internal/shared/http.go - remove stripping of Authorization and x-api-key - add Request Context middleware to internal/server - add /ui and /metrics behind auth middleware, fixes #717 Fix #717 Updates: #834 --- internal/router/base.go | 13 +- internal/router/loading_test.go | 63 ------ internal/router/peer.go | 9 +- internal/router/peer_test.go | 21 +- internal/router/router.go | 156 +------------- internal/router/scheduler/scheduler.go | 6 +- internal/server/api.go | 9 +- internal/server/apigroup.go | 19 +- internal/server/auth.go | 48 ++--- internal/server/auth_test.go | 73 +++---- internal/server/concurrency.go | 6 +- internal/server/concurrency_test.go | 4 +- internal/server/filters.go | 18 +- internal/server/log.go | 6 +- internal/server/metrics_middleware.go | 6 +- internal/server/server.go | 30 ++- internal/shared/http.go | 202 ++++++++++++++++++ .../router_test.go => shared/http_test.go} | 185 ++++++++++++++-- 18 files changed, 497 insertions(+), 377 deletions(-) create mode 100644 internal/shared/http.go rename internal/{router/router_test.go => shared/http_test.go} (56%) diff --git a/internal/router/base.go b/internal/router/base.go index ab6a3c4d..a14d2e09 100644 --- a/internal/router/base.go +++ b/internal/router/base.go @@ -12,6 +12,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/router/scheduler" + "github.com/mostlygeek/llama-swap/internal/shared" ) type shutdownReq struct { @@ -399,13 +400,13 @@ func (b *baseRouter) Shutdown(timeout time.Duration) error { func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { if b.shuttingDown.Load() { - SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) return } - data, err := FetchContext(req, b.config) + data, err := shared.FetchContext(req, b.config) if err != nil { - SendError(w, req, err) + shared.SendError(w, req, err) return } @@ -424,7 +425,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { case <-req.Context().Done(): return case <-b.shutdownCtx.Done(): - SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) return } @@ -475,12 +476,12 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { return case <-b.shutdownCtx.Done(): finishLoading() - SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) + shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name)) return } if resp.Err != nil { - SendError(w, req, resp.Err) + shared.SendError(w, req, resp.Err) return } resp.HandleFunc(w, req) diff --git a/internal/router/loading_test.go b/internal/router/loading_test.go index 57dc3bf6..6863708e 100644 --- a/internal/router/loading_test.go +++ b/internal/router/loading_test.go @@ -226,69 +226,6 @@ func TestIsLoadingPath(t *testing.T) { } } -func TestExtractContext_Streaming_GET(t *testing.T) { - tests := []struct { - name string - query string - wantStreaming bool - }{ - {"streaming true", "model=llama3&stream=true", true}, - {"streaming false", "model=llama3&stream=false", false}, - {"no stream param", "model=llama3", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil) - got, err := ExtractContext(r) - if err != nil { - t.Fatalf("ExtractContext: %v", err) - } - if got.Streaming != tt.wantStreaming { - t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming) - } - }) - } -} - -func TestExtractContext_Streaming_JSON(t *testing.T) { - tests := []struct { - name string - body string - wantStreaming bool - }{ - {"streaming true", `{"model":"llama3","stream":true}`, true}, - {"streaming false", `{"model":"llama3","stream":false}`, false}, - {"no stream param", `{"model":"llama3"}`, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body)) - r.Header.Set("Content-Type", "application/json") - got, err := ExtractContext(r) - if err != nil { - t.Fatalf("ExtractContext: %v", err) - } - if got.Streaming != tt.wantStreaming { - t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming) - } - }) - } -} - -func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) { - r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true")) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - got, err := ExtractContext(r) - if err != nil { - t.Fatalf("ExtractContext: %v", err) - } - if !got.Streaming { - t.Error("Streaming should be true") - } -} - func countSSEMessages(s string) int { scanner := bufio.NewScanner(strings.NewReader(s)) count := 0 diff --git a/internal/router/peer.go b/internal/router/peer.go index fb17068c..e46447a6 100644 --- a/internal/router/peer.go +++ b/internal/router/peer.go @@ -15,6 +15,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/shared" ) type peerMember struct { @@ -146,22 +147,22 @@ func (r *Peer) Shutdown(timeout time.Duration) error { func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if r.shuttingDown.Load() { - SendError(w, req, fmt.Errorf("peer proxy is shutting down")) + shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down")) return } r.inflight.Add(1) defer r.inflight.Done() - data, err := FetchContext(req, r.cfg) + data, err := shared.FetchContext(req, r.cfg) if err != nil { - SendError(w, req, err) + shared.SendError(w, req, err) return } pp, found := r.peers[data.ModelID] if !found { r.logger.Warnf("peer model not found: %s", data.ModelID) - SendError(w, req, ErrNoPeerModelFound) + shared.SendError(w, req, ErrNoPeerModelFound) return } diff --git a/internal/router/peer_test.go b/internal/router/peer_test.go index 74527bcb..1641c420 100644 --- a/internal/router/peer_test.go +++ b/internal/router/peer_test.go @@ -12,6 +12,7 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/shared" ) var testLogger = logmon.NewWriter(os.Stdout) @@ -142,7 +143,7 @@ func TestPeer_ServeHTTP_Success(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -178,7 +179,7 @@ func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -212,7 +213,7 @@ func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -246,7 +247,7 @@ func TestPeer_ServeHTTP_NoApiKey(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -279,7 +280,7 @@ func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -311,7 +312,7 @@ func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -347,7 +348,7 @@ func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) @@ -385,7 +386,7 @@ func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) var wg sync.WaitGroup wg.Add(1) @@ -448,7 +449,7 @@ func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) { } req := httptest.NewRequest("POST", "/v1/chat/completions", nil) - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"})) var wg sync.WaitGroup wg.Add(1) @@ -551,7 +552,7 @@ func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) { body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`) req := httptest.NewRequest("POST", "/v1/chat/completions", body) req.Header.Set("Content-Type", "application/json") - *req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"})) + *req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "context-model", ModelID: "context-model"})) w := httptest.NewRecorder() pr.ServeHTTP(w, req) diff --git a/internal/router/router.go b/internal/router/router.go index 8ebc1f11..561c5fc8 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -1,40 +1,18 @@ package router import ( - "bytes" - "context" - "errors" - "fmt" - "io" "net/http" - "strings" "time" - "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" - "github.com/mostlygeek/llama-swap/internal/router/scheduler" - "github.com/tidwall/gjson" + "github.com/mostlygeek/llama-swap/internal/shared" ) -type contextkey struct { - name string -} - -type ReqContextData struct { - Model string - ModelID string - Streaming bool - SendLoadingState bool -} - var ( - ErrNoModelInContext = fmt.Errorf("no model in request context") - ErrNoRouterFound = fmt.Errorf("no router found for model") - ErrNoPeerModelFound = fmt.Errorf("peer model not found") - ErrNoLocalModelFound = scheduler.ErrModelNotFound - - ContextKey = &contextkey{"context"} + ErrNoRouterFound = shared.ErrNoRouterFound + ErrNoPeerModelFound = shared.ErrNoPeerModelFound + ErrNoLocalModelFound = shared.ErrNoLocalModelFound ) type Router interface { @@ -72,129 +50,3 @@ type LocalRouter interface { // model is not known to this router. ProcessLogger(modelID string) (*logmon.Monitor, bool) } - -// FetchContext will attempt to get the model id from the context then -// from the model body. If it extracts the model from the body it will -// store the model in the context for downstream handlers. An error -// will be returned when model can not be fetch from either location. -func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { - data, ok := ReadContext(r.Context()) - if ok { - return data, nil - } - - if data, err := ExtractContext(r); err == nil { - realName, _ := cfg.RealModelName(data.Model) - if realName == "" { - realName = data.Model - } - data.ModelID = realName - if mc, ok := cfg.Models[realName]; ok { - data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState - } - *r = *r.WithContext(SetContext(r.Context(), data)) - return data, nil - } - - return ReqContextData{}, ErrNoModelInContext -} - -func SetContext(ctx context.Context, data ReqContextData) context.Context { - return context.WithValue(ctx, ContextKey, data) -} - -func ReadContext(ctx context.Context) (ReqContextData, bool) { - data, ok := ctx.Value(ContextKey).(ReqContextData) - return data, ok -} - -// ExtractContext pulls the model name from an HTTP request without consuming the -// body. For GET requests it reads the "model" query parameter. For POST -// requests it inspects Content-Type and parses JSON, multipart/form-data, or -// application/x-www-form-urlencoded bodies. The request body is always restored -// before returning so downstream handlers — including reverse proxies that -// forward raw bytes upstream — can still read it. -func ExtractContext(r *http.Request) (ReqContextData, error) { - if r.Method == http.MethodGet { - if model := r.URL.Query().Get("model"); model != "" { - return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil - } - return ReqContextData{}, fmt.Errorf("missing 'model' query parameter") - } - - bodyBytes, err := io.ReadAll(r.Body) - if err != nil { - return ReqContextData{}, fmt.Errorf("error reading request body: %w", err) - } - defer func() { - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - }() - - contentType := r.Header.Get("Content-Type") - - if strings.Contains(contentType, "application/json") { - model := gjson.GetBytes(bodyBytes, "model").String() - if model == "" { - return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body") - } - return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil - } - - // Form parsers read from r.Body, so feed them a fresh reader over the - // buffered bytes. The deferred restore above will reset r.Body again - // after parsing. - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - if strings.Contains(contentType, "multipart/form-data") { - if err := r.ParseMultipartForm(32 << 20); err != nil { - return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err) - } - } else { - if err := r.ParseForm(); err != nil { - return ReqContextData{}, fmt.Errorf("error parsing form: %w", err) - } - } - - if model := r.FormValue("model"); model != "" { - return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil - } - - return ReqContextData{}, fmt.Errorf("missing 'model' parameter") -} - -func SendError(w http.ResponseWriter, r *http.Request, err error) { - switch { - case errors.Is(err, ErrNoModelInContext): - SendResponse(w, r, http.StatusNotFound, "no model id could be identified") - case errors.Is(err, ErrNoPeerModelFound): - SendResponse(w, r, http.StatusNotFound, "no peer found for requested model") - case errors.Is(err, ErrNoLocalModelFound): - SendResponse(w, r, http.StatusNotFound, "no local server found for requested model") - case errors.Is(err, ErrNoRouterFound): - SendResponse(w, r, http.StatusNotFound, "no router for requested model") - default: - SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err)) - } -} - -// SendResponse detects what content type the client prefers and returns an error response in that format. -func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) { - // Check Accept header for preferred response format - acceptHeader := r.Header.Get("Accept") - if strings.Contains(acceptHeader, "text/plain") { - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(status) - w.Write([]byte(fmt.Sprintf("llama-swap: %s", message))) - return - } - - if strings.Contains(acceptHeader, "text/html") { - w.Header().Set("Content-Type", "text/html") - w.WriteHeader(status) - w.Write([]byte(fmt.Sprintf(`
%s
`, message))) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message))) -} diff --git a/internal/router/scheduler/scheduler.go b/internal/router/scheduler/scheduler.go index 9dc9e281..87ed6ad2 100644 --- a/internal/router/scheduler/scheduler.go +++ b/internal/router/scheduler/scheduler.go @@ -11,17 +11,17 @@ package scheduler import ( "context" - "fmt" "net/http" "time" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" + "github.com/mostlygeek/llama-swap/internal/shared" ) // ErrModelNotFound is granted to callers whose model is not handled by this -// router. The router package aliases it so SendError can match it. -var ErrModelNotFound = fmt.Errorf("local model not found") +// router. It is an alias for shared.ErrNoLocalModelFound. +var ErrModelNotFound = shared.ErrNoLocalModelFound // Swapper is the eviction policy: it decides which running models must be // stopped before a target can serve. It is orthogonal to the scheduling diff --git a/internal/server/api.go b/internal/server/api.go index d3723c63..fe11ac2b 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -9,7 +9,6 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/event" - "github.com/mostlygeek/llama-swap/internal/router" "github.com/mostlygeek/llama-swap/internal/shared" ) @@ -163,7 +162,7 @@ func (s *Server) startPreload() { if err != nil { continue } - req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID})) + req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID})) dw := &discardResponseWriter{status: http.StatusOK} s.local.ServeHTTP(dw, req) @@ -208,7 +207,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath) if !found { - router.SendResponse(w, r, http.StatusNotFound, "model not found") + shared.SendResponse(w, r, http.StatusNotFound, "model not found") return } @@ -230,7 +229,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { // Strip the /upstream/%s
`, html.EscapeString(message)))) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + resp, err := json.Marshal(map[string]string{"src": "llama-swap", "error": message}) + if err != nil { + w.Write([]byte(`{"src":"llama-swap", "error": "failed to marshal response"}`)) + return + } + w.Write(resp) +} + +// FetchContext will attempt to get the model id from the context then +// from the model body. If it extracts the model from the body it will +// store the model in the context for downstream handlers. An error +// will be returned when model can not be fetch from either location. +func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { + data, ok := ReadContext(r.Context()) + if ok { + return data, nil + } + + if data, err := extractContext(r); err == nil && data.Model != "" { + realName, _ := cfg.RealModelName(data.Model) + if realName == "" { + realName = data.Model + } + data.ModelID = realName + if mc, ok := cfg.Models[realName]; ok { + data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState + } + *r = *r.WithContext(SetContext(r.Context(), data)) + return data, nil + } + + return ReqContextData{}, ErrNoModelInContext +} + +func SetContext(ctx context.Context, data ReqContextData) context.Context { + return context.WithValue(ctx, ReqContextKey, data) +} + +func ReadContext(ctx context.Context) (ReqContextData, bool) { + data, ok := ctx.Value(ReqContextKey).(ReqContextData) + return data, ok +} + +// extractContext pulls fields from an HTTP request into a ReqContextData, +// returning whatever is available. For GET requests it reads query parameters. +// For POST requests it inspects Content-Type and parses JSON, +// multipart/form-data, or application/x-www-form-urlencoded bodies. The +// request body is always restored before returning. An error is returned only +// for I/O or parse failures, not for missing fields. +func extractContext(r *http.Request) (ReqContextData, error) { + + apiKey := ExtractAPIKey(r) + + if r.Method == http.MethodGet { + q := r.URL.Query() + return ReqContextData{ + Model: q.Get("model"), + Streaming: q.Get("stream") == "true", + ApiKey: apiKey, + }, nil + } + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return ReqContextData{}, fmt.Errorf("error reading request body: %w", err) + } + defer func() { + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + }() + + contentType := r.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/json") { + return ReqContextData{ + Model: gjson.GetBytes(bodyBytes, "model").String(), + Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(), + ApiKey: apiKey, + }, nil + } + + // Form parsers read from r.Body, so feed them a fresh reader over the + // buffered bytes. The deferred restore above will reset r.Body again + // after parsing. + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + if strings.Contains(contentType, "multipart/form-data") { + if err := r.ParseMultipartForm(32 << 20); err != nil { + return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err) + } + } else { + if err := r.ParseForm(); err != nil { + return ReqContextData{}, fmt.Errorf("error parsing form: %w", err) + } + } + + return ReqContextData{ + Model: r.FormValue("model"), + Streaming: r.FormValue("stream") == "true", + ApiKey: apiKey, + }, nil +} + +// extractAPIKey pulls a candidate API key from the request, preferring Basic, +// then Bearer, then x-api-key. +func ExtractAPIKey(r *http.Request) string { + var bearerKey, basicKey string + if auth := r.Header.Get("Authorization"); auth != "" { + scheme, credentials, ok := strings.Cut(auth, " ") + if ok { + switch strings.ToLower(scheme) { + case "bearer": + bearerKey = credentials + case "basic": + if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil { + if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 { + basicKey = parts[1] // password field is the API key + } + } + } + } + } + + switch { + case basicKey != "": + return basicKey + case bearerKey != "": + return bearerKey + default: + return r.Header.Get("x-api-key") + } +} diff --git a/internal/router/router_test.go b/internal/shared/http_test.go similarity index 56% rename from internal/router/router_test.go rename to internal/shared/http_test.go index fa88364c..8bbdd69c 100644 --- a/internal/router/router_test.go +++ b/internal/shared/http_test.go @@ -1,11 +1,13 @@ -package router +package shared import ( "bytes" "context" + "encoding/base64" "io" "mime/multipart" "net/http" + "net/http/httptest" "net/url" "strings" "testing" @@ -20,13 +22,13 @@ func TestExtractContext_GET(t *testing.T) { }{ {"model present", "model=llama3", "llama3", false}, {"model with slashes", "model=author/model-7b", "author/model-7b", false}, - {"model missing", "", "", true}, + {"model missing", "", "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil) - got, err := ExtractContext(r) + got, err := extractContext(r) if (err != nil) != tt.wantErr { t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) } @@ -46,16 +48,16 @@ func TestExtractContext_JSON(t *testing.T) { }{ {"model present", `{"model":"llama3","stream":true}`, "llama3", false}, {"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false}, - {"model empty string", `{"model":""}`, "", true}, - {"model key missing", `{"stream":true}`, "", true}, - {"invalid json", `not-json`, "", true}, + {"model empty string", `{"model":""}`, "", false}, + {"model key missing", `{"stream":true}`, "", false}, + {"invalid json", `not-json`, "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body)) r.Header.Set("Content-Type", "application/json") - got, err := ExtractContext(r) + got, err := extractContext(r) if (err != nil) != tt.wantErr { t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) } @@ -74,7 +76,7 @@ func TestExtractContext_URLEncodedForm(t *testing.T) { wantErr bool }{ {"model present", "whisper-1", "whisper-1", false}, - {"model missing", "", "", true}, + {"model missing", "", "", false}, } for _, tt := range tests { @@ -85,7 +87,7 @@ func TestExtractContext_URLEncodedForm(t *testing.T) { } r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode())) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - got, err := ExtractContext(r) + got, err := extractContext(r) if (err != nil) != tt.wantErr { t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) } @@ -104,7 +106,7 @@ func TestExtractContext_MultipartForm(t *testing.T) { wantErr bool }{ {"model present", "whisper-1", "whisper-1", false}, - {"model missing", "", "", true}, + {"model missing", "", "", false}, } for _, tt := range tests { @@ -119,7 +121,7 @@ func TestExtractContext_MultipartForm(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf) r.Header.Set("Content-Type", mw.FormDataContentType()) - got, err := ExtractContext(r) + got, err := extractContext(r) if (err != nil) != tt.wantErr { t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err) } @@ -135,7 +137,7 @@ func TestExtractContext_JSONBodyRestored(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) r.Header.Set("Content-Type", "application/json") - if _, err := ExtractContext(r); err != nil { + if _, err := extractContext(r); err != nil { t.Fatalf("ExtractContext: %v", err) } @@ -162,7 +164,7 @@ func TestExtractContext_MultipartBodyRestored(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original)) r.Header.Set("Content-Type", mw.FormDataContentType()) - if _, err := ExtractContext(r); err != nil { + if _, err := extractContext(r); err != nil { t.Fatalf("ExtractContext: %v", err) } @@ -180,7 +182,7 @@ func TestExtractContext_URLEncodedBodyRestored(t *testing.T) { r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body)) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if _, err := ExtractContext(r); err != nil { + if _, err := extractContext(r); err != nil { t.Fatalf("ExtractContext: %v", err) } @@ -195,7 +197,7 @@ func TestExtractContext_URLEncodedBodyRestored(t *testing.T) { func TestSetContext(t *testing.T) { ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}) - data, ok := ctx.Value(ContextKey).(ReqContextData) + data, ok := ctx.Value(ReqContextKey).(ReqContextData) if !ok { t.Fatalf("ContextKey not set or wrong type") } @@ -209,7 +211,7 @@ func TestSetContext(t *testing.T) { func TestSetContext_WithAlias(t *testing.T) { ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}) - data, _ := ctx.Value(ContextKey).(ReqContextData) + data, _ := ctx.Value(ReqContextKey).(ReqContextData) if data.Model != "llama" { t.Errorf("want requested %q got %q", "llama", data.Model) } @@ -221,7 +223,7 @@ func TestSetContext_WithAlias(t *testing.T) { func TestSetContext_DoesNotMutateParent(t *testing.T) { parent := context.Background() _ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"}) - if v := parent.Value(ContextKey); v != nil { + if v := parent.Value(ReqContextKey); v != nil { t.Errorf("parent context was mutated: %v", v) } } @@ -273,3 +275,152 @@ func TestReadContext(t *testing.T) { }) } } + +func TestExtractContext_Streaming_GET(t *testing.T) { + tests := []struct { + name string + query string + wantStreaming bool + }{ + {"streaming true", "model=llama3&stream=true", true}, + {"streaming false", "model=llama3&stream=false", false}, + {"no stream param", "model=llama3", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil) + got, err := extractContext(r) + if err != nil { + t.Fatalf("ExtractContext: %v", err) + } + if got.Streaming != tt.wantStreaming { + t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming) + } + }) + } +} + +func TestExtractContext_Streaming_JSON(t *testing.T) { + tests := []struct { + name string + body string + wantStreaming bool + }{ + {"streaming true", `{"model":"llama3","stream":true}`, true}, + {"streaming false", `{"model":"llama3","stream":false}`, false}, + {"no stream param", `{"model":"llama3"}`, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body)) + r.Header.Set("Content-Type", "application/json") + got, err := extractContext(r) + if err != nil { + t.Fatalf("ExtractContext: %v", err) + } + if got.Streaming != tt.wantStreaming { + t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming) + } + }) + } +} + +func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true")) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + got, err := extractContext(r) + if err != nil { + t.Fatalf("ExtractContext: %v", err) + } + if !got.Streaming { + t.Error("Streaming should be true") + } +} + +func TestExtractContext_ApiKey(t *testing.T) { + basicHeader := func(user, pass string) string { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) + } + cases := []struct { + name string + method string + ct string + body string + auth string + xapi string + wantKey string + }{ + {"GET bearer", http.MethodGet, "", "", "Bearer sk-get", "", "sk-get"}, + {"GET x-api-key", http.MethodGet, "", "", "", "xk-get", "xk-get"}, + {"GET basic", http.MethodGet, "", "", basicHeader("u", "pw-get"), "", "pw-get"}, + {"JSON bearer", http.MethodPost, "application/json", `{"model":"m"}`, "Bearer sk-json", "", "sk-json"}, + {"JSON x-api-key", http.MethodPost, "application/json", `{"model":"m"}`, "", "xk-json", "xk-json"}, + {"form bearer", http.MethodPost, "application/x-www-form-urlencoded", "model=m", "Bearer sk-form", "", "sk-form"}, + {"no key", http.MethodGet, "", "", "", "", ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var body io.Reader + if c.body != "" { + body = strings.NewReader(c.body) + } + r, _ := http.NewRequest(c.method, "/", body) + if c.ct != "" { + r.Header.Set("Content-Type", c.ct) + } + if c.auth != "" { + r.Header.Set("Authorization", c.auth) + } + if c.xapi != "" { + r.Header.Set("x-api-key", c.xapi) + } + got, err := extractContext(r) + if err != nil { + t.Fatalf("extractContext: %v", err) + } + if got.ApiKey != c.wantKey { + t.Errorf("ApiKey = %q, want %q", got.ApiKey, c.wantKey) + } + }) + } +} + +func TestServer_ExtractAPIKey(t *testing.T) { + basicHeader := func(user, pass string) string { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) + } + cases := []struct { + name string + auth string + xapi string + want string + }{ + {"none", "", "", ""}, + {"bearer", "Bearer tok123", "", "tok123"}, + {"basic", basicHeader("user", "pw-key"), "", "pw-key"}, + {"x-api-key", "", "xkey", "xkey"}, + {"basic beats bearer", basicHeader("u", "bk"), "", "bk"}, + {"bearer beats x-api-key", "Bearer btok", "xkey", "btok"}, + {"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"}, + {"lowercase bearer", "bearer tok123", "", "tok123"}, + {"lowercase basic", "basic " + base64.StdEncoding.EncodeToString([]byte("user:pw-key")), "", "pw-key"}, + {"mixed case BEARER", "BEARER tok456", "", "tok456"}, + {"mixed case bAsIc", "bAsIc " + base64.StdEncoding.EncodeToString([]byte("u:bk")), "", "bk"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + if c.auth != "" { + r.Header.Set("Authorization", c.auth) + } + if c.xapi != "" { + r.Header.Set("x-api-key", c.xapi) + } + if got := ExtractAPIKey(r); got != c.want { + t.Errorf("extractAPIKey() = %q, want %q", got, c.want) + } + }) + } +}