package server import ( "net/http" "net/http/httptest" "strings" "testing" "github.com/mostlygeek/llama-swap/internal/config" ) func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) { cases := []struct { in string want string }{ {"Content-Type, Authorization", "Content-Type, Authorization"}, {" X-Custom , Accept ", "X-Custom, Accept"}, {"Valid, Bad Header", "Valid"}, {"Bad@Header", ""}, {"", ""}, } for _, c := range cases { if got := sanitizeAccessControlRequestHeaderValues(c.in); got != c.want { t.Errorf("sanitize(%q) = %q, want %q", c.in, got, c.want) } } } func TestServer_IsTokenChar(t *testing.T) { for _, r := range "abcXYZ0129!#$%&'*+-.^_`|~" { if !isTokenChar(r) { t.Errorf("isTokenChar(%q) = false, want true", r) } } for _, r := range " @()/\t\"" { if isTokenChar(r) { t.Errorf("isTokenChar(%q) = true, want false", r) } } } func TestServer_RequestContextMiddleware(t *testing.T) { cfg := config.Config{ Models: map[string]config.ModelConfig{ "llama3": {}, }, } final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) mw := CreateRequestContextMiddleware(cfg) t.Run("known model passes through", func(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"llama3"}`)) r.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() mw(final).ServeHTTP(w, r) if w.Code != http.StatusOK { t.Errorf("status = %d, want 200", w.Code) } }) t.Run("missing model returns 404", func(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`)) r.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() mw(final).ServeHTTP(w, r) if w.Code != http.StatusNotFound { t.Errorf("status = %d, want 404", w.Code) } }) } func TestServer_AuthMiddleware(t *testing.T) { final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) t.Run("no keys configured passes through", func(t *testing.T) { mw := CreateAuthMiddleware(config.Config{}) w := httptest.NewRecorder() mw(final).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != http.StatusOK { t.Errorf("status = %d, want 200", w.Code) } }) cfg := config.Config{RequiredAPIKeys: []string{"secret"}} t.Run("valid key", func(t *testing.T) { mw := CreateAuthMiddleware(cfg) r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Authorization", "Bearer secret") w := httptest.NewRecorder() mw(final).ServeHTTP(w, r) if w.Code != http.StatusOK { t.Errorf("status = %d, want 200", w.Code) } }) t.Run("invalid key", func(t *testing.T) { mw := CreateAuthMiddleware(cfg) r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Authorization", "Bearer wrong") w := httptest.NewRecorder() mw(final).ServeHTTP(w, r) if w.Code != http.StatusUnauthorized { t.Errorf("status = %d, want 401", w.Code) } if w.Header().Get("WWW-Authenticate") == "" { t.Error("missing WWW-Authenticate header") } }) }