internal/router,server,shared: refactor auth, libs (#839)

- refactor shared http functionality into internal/shared/http.go
- remove stripping of Authorization and x-api-key
- add Request Context middleware to internal/server
- add /ui and /metrics behind auth middleware, fixes #717

Fix #717
Updates: #834
This commit is contained in:
Benson Wong
2026-06-13 10:19:04 -07:00
committed by GitHub
parent 8c660dcb90
commit 62aea0e83d
18 changed files with 497 additions and 377 deletions
+202
View File
@@ -0,0 +1,202 @@
package shared
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"html"
"io"
"net/http"
"strings"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/tidwall/gjson"
)
type contextkey struct {
name string
}
type ReqContextData struct {
ApiKey string
Model string
ModelID string
Streaming bool
SendLoadingState bool
}
var (
ReqContextKey = &contextkey{"context"}
ErrNoModelInContext = fmt.Errorf("no model in request context")
ErrNoRouterFound = fmt.Errorf("no router found for model")
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
ErrNoLocalModelFound = fmt.Errorf("local model not found")
)
func SendError(w http.ResponseWriter, r *http.Request, err error) {
switch {
case errors.Is(err, ErrNoModelInContext):
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
case errors.Is(err, ErrNoPeerModelFound):
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
case errors.Is(err, ErrNoLocalModelFound):
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
case errors.Is(err, ErrNoRouterFound):
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
default:
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
}
}
// SendResponse detects what content type the client prefers and returns an error response in that format.
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
acceptHeader := r.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/plain") {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(status)
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
return
}
if strings.Contains(acceptHeader, "text/html") {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(status)
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, html.EscapeString(message))))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
resp, err := json.Marshal(map[string]string{"src": "llama-swap", "error": message})
if err != nil {
w.Write([]byte(`{"src":"llama-swap", "error": "failed to marshal response"}`))
return
}
w.Write(resp)
}
// FetchContext will attempt to get the model id from the context then
// from the model body. If it extracts the model from the body it will
// store the model in the context for downstream handlers. An error
// will be returned when model can not be fetch from either location.
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
data, ok := ReadContext(r.Context())
if ok {
return data, nil
}
if data, err := extractContext(r); err == nil && data.Model != "" {
realName, _ := cfg.RealModelName(data.Model)
if realName == "" {
realName = data.Model
}
data.ModelID = realName
if mc, ok := cfg.Models[realName]; ok {
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
}
*r = *r.WithContext(SetContext(r.Context(), data))
return data, nil
}
return ReqContextData{}, ErrNoModelInContext
}
func SetContext(ctx context.Context, data ReqContextData) context.Context {
return context.WithValue(ctx, ReqContextKey, data)
}
func ReadContext(ctx context.Context) (ReqContextData, bool) {
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
return data, ok
}
// extractContext pulls fields from an HTTP request into a ReqContextData,
// returning whatever is available. For GET requests it reads query parameters.
// For POST requests it inspects Content-Type and parses JSON,
// multipart/form-data, or application/x-www-form-urlencoded bodies. The
// request body is always restored before returning. An error is returned only
// for I/O or parse failures, not for missing fields.
func extractContext(r *http.Request) (ReqContextData, error) {
apiKey := ExtractAPIKey(r)
if r.Method == http.MethodGet {
q := r.URL.Query()
return ReqContextData{
Model: q.Get("model"),
Streaming: q.Get("stream") == "true",
ApiKey: apiKey,
}, nil
}
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
}
defer func() {
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}()
contentType := r.Header.Get("Content-Type")
if strings.Contains(contentType, "application/json") {
return ReqContextData{
Model: gjson.GetBytes(bodyBytes, "model").String(),
Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
ApiKey: apiKey,
}, nil
}
// Form parsers read from r.Body, so feed them a fresh reader over the
// buffered bytes. The deferred restore above will reset r.Body again
// after parsing.
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if strings.Contains(contentType, "multipart/form-data") {
if err := r.ParseMultipartForm(32 << 20); err != nil {
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
}
} else {
if err := r.ParseForm(); err != nil {
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
}
}
return ReqContextData{
Model: r.FormValue("model"),
Streaming: r.FormValue("stream") == "true",
ApiKey: apiKey,
}, nil
}
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
// then Bearer, then x-api-key.
func ExtractAPIKey(r *http.Request) string {
var bearerKey, basicKey string
if auth := r.Header.Get("Authorization"); auth != "" {
scheme, credentials, ok := strings.Cut(auth, " ")
if ok {
switch strings.ToLower(scheme) {
case "bearer":
bearerKey = credentials
case "basic":
if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil {
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
basicKey = parts[1] // password field is the API key
}
}
}
}
}
switch {
case basicKey != "":
return basicKey
case bearerKey != "":
return bearerKey
default:
return r.Header.Get("x-api-key")
}
}
+426
View File
@@ -0,0 +1,426 @@
package shared
import (
"bytes"
"context"
"encoding/base64"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestExtractContext_GET(t *testing.T) {
tests := []struct {
name string
query string
wantModel string
wantErr bool
}{
{"model present", "model=llama3", "llama3", false},
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
{"model missing", "", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
got, err := extractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_JSON(t *testing.T) {
tests := []struct {
name string
body string
wantModel string
wantErr bool
}{
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
{"model empty string", `{"model":""}`, "", false},
{"model key missing", `{"stream":true}`, "", false},
{"invalid json", `not-json`, "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
got, err := extractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_URLEncodedForm(t *testing.T) {
tests := []struct {
name string
formModel string
wantModel string
wantErr bool
}{
{"model present", "whisper-1", "whisper-1", false},
{"model missing", "", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
form := url.Values{}
if tt.formModel != "" {
form.Set("model", tt.formModel)
}
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
got, err := extractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_MultipartForm(t *testing.T) {
tests := []struct {
name string
formModel string
wantModel string
wantErr bool
}{
{"model present", "whisper-1", "whisper-1", false},
{"model missing", "", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
if tt.formModel != "" {
fw, _ := mw.CreateFormField("model")
fw.Write([]byte(tt.formModel))
}
mw.Close()
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
r.Header.Set("Content-Type", mw.FormDataContentType())
got, err := extractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_JSONBodyRestored(t *testing.T) {
body := `{"model":"llama3","stream":true}`
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/json")
if _, err := extractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if string(remaining) != body {
t.Errorf("body not restored: want %q got %q", body, string(remaining))
}
}
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
fw, _ := mw.CreateFormField("model")
fw.Write([]byte("whisper-1"))
ff, _ := mw.CreateFormFile("file", "audio.wav")
ff.Write([]byte("fake-audio-bytes"))
mw.Close()
original := buf.Bytes()
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
r.Header.Set("Content-Type", mw.FormDataContentType())
if _, err := extractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if !bytes.Equal(remaining, original) {
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
}
}
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
body := "model=whisper-1&extra=value"
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if _, err := extractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if string(remaining) != body {
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
}
}
func TestSetContext(t *testing.T) {
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
if !ok {
t.Fatalf("ContextKey not set or wrong type")
}
if data.Model != "llama3" {
t.Errorf("want %q got %q", "llama3", data.Model)
}
if data.ModelID != "llama3" {
t.Errorf("want %q got %q", "llama3", data.ModelID)
}
}
func TestSetContext_WithAlias(t *testing.T) {
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
data, _ := ctx.Value(ReqContextKey).(ReqContextData)
if data.Model != "llama" {
t.Errorf("want requested %q got %q", "llama", data.Model)
}
if data.ModelID != "llama3" {
t.Errorf("want real %q got %q", "llama3", data.ModelID)
}
}
func TestSetContext_DoesNotMutateParent(t *testing.T) {
parent := context.Background()
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
if v := parent.Value(ReqContextKey); v != nil {
t.Errorf("parent context was mutated: %v", v)
}
}
func TestReadContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
wantReq string
wantReal string
wantBool bool
}{
{
name: "model present, same name",
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
wantReq: "llama3",
wantReal: "llama3",
wantBool: true,
},
{
name: "model present, aliased",
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
wantReq: "llama",
wantReal: "llama3",
wantBool: true,
},
{
name: "model absent",
ctx: context.Background(),
wantReq: "",
wantReal: "",
wantBool: false,
},
{
name: "model is empty string",
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
wantReq: "",
wantReal: "",
wantBool: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotData, ok := ReadContext(tt.ctx)
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
}
})
}
}
func TestExtractContext_Streaming_GET(t *testing.T) {
tests := []struct {
name string
query string
wantStreaming bool
}{
{"streaming true", "model=llama3&stream=true", true},
{"streaming false", "model=llama3&stream=false", false},
{"no stream param", "model=llama3", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
got, err := extractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if got.Streaming != tt.wantStreaming {
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
}
})
}
}
func TestExtractContext_Streaming_JSON(t *testing.T) {
tests := []struct {
name string
body string
wantStreaming bool
}{
{"streaming true", `{"model":"llama3","stream":true}`, true},
{"streaming false", `{"model":"llama3","stream":false}`, false},
{"no stream param", `{"model":"llama3"}`, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
got, err := extractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if got.Streaming != tt.wantStreaming {
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
}
})
}
}
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
got, err := extractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if !got.Streaming {
t.Error("Streaming should be true")
}
}
func TestExtractContext_ApiKey(t *testing.T) {
basicHeader := func(user, pass string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
}
cases := []struct {
name string
method string
ct string
body string
auth string
xapi string
wantKey string
}{
{"GET bearer", http.MethodGet, "", "", "Bearer sk-get", "", "sk-get"},
{"GET x-api-key", http.MethodGet, "", "", "", "xk-get", "xk-get"},
{"GET basic", http.MethodGet, "", "", basicHeader("u", "pw-get"), "", "pw-get"},
{"JSON bearer", http.MethodPost, "application/json", `{"model":"m"}`, "Bearer sk-json", "", "sk-json"},
{"JSON x-api-key", http.MethodPost, "application/json", `{"model":"m"}`, "", "xk-json", "xk-json"},
{"form bearer", http.MethodPost, "application/x-www-form-urlencoded", "model=m", "Bearer sk-form", "", "sk-form"},
{"no key", http.MethodGet, "", "", "", "", ""},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
var body io.Reader
if c.body != "" {
body = strings.NewReader(c.body)
}
r, _ := http.NewRequest(c.method, "/", body)
if c.ct != "" {
r.Header.Set("Content-Type", c.ct)
}
if c.auth != "" {
r.Header.Set("Authorization", c.auth)
}
if c.xapi != "" {
r.Header.Set("x-api-key", c.xapi)
}
got, err := extractContext(r)
if err != nil {
t.Fatalf("extractContext: %v", err)
}
if got.ApiKey != c.wantKey {
t.Errorf("ApiKey = %q, want %q", got.ApiKey, c.wantKey)
}
})
}
}
func TestServer_ExtractAPIKey(t *testing.T) {
basicHeader := func(user, pass string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
}
cases := []struct {
name string
auth string
xapi string
want string
}{
{"none", "", "", ""},
{"bearer", "Bearer tok123", "", "tok123"},
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
{"x-api-key", "", "xkey", "xkey"},
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
{"lowercase bearer", "bearer tok123", "", "tok123"},
{"lowercase basic", "basic " + base64.StdEncoding.EncodeToString([]byte("user:pw-key")), "", "pw-key"},
{"mixed case BEARER", "BEARER tok456", "", "tok456"},
{"mixed case bAsIc", "bAsIc " + base64.StdEncoding.EncodeToString([]byte("u:bk")), "", "bk"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
if c.auth != "" {
r.Header.Set("Authorization", c.auth)
}
if c.xapi != "" {
r.Header.Set("x-api-key", c.xapi)
}
if got := ExtractAPIKey(r); got != c.want {
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
}
})
}
}