package server

import (
	"crypto/subtle"
	"encoding/base64"
	"net/http"
	"strings"

	"github.com/mostlygeek/llama-swap/internal/chain"
	"github.com/mostlygeek/llama-swap/internal/config"
	"github.com/mostlygeek/llama-swap/internal/router"
)

// CreateAuthMiddleware returns middleware that validates API keys when the
// config declares any in cfg.RequiredAPIKeys. It accepts the key via
// Authorization: Bearer, Authorization: Basic (password field), or x-api-key.
// Requests without a matching key receive a 401 with a Basic WWW-Authenticate
// challenge. On success the Authorization and x-api-key headers are removed
// from the request handed to next so they never leak to upstream. When no
// keys are configured the middleware is a pass-through.
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
	keys := cfg.RequiredAPIKeys
	return func(next http.Handler) http.Handler {
		if len(keys) == 0 {
			return next
		}
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			provided := []byte(extractAPIKey(r))

			// An empty credential never authenticates, even if an empty string
			// ended up in the configured key list (e.g. an unset variable).
			match := 0
			if len(provided) > 0 {
				for _, key := range keys {
					// Constant-time comparison so response timing does not reveal
					// how much of a key matched. Every key is checked (no early
					// break) so timing also does not reveal which key matched.
					match |= subtle.ConstantTimeCompare(provided, []byte(key))
				}
			}
			if match != 1 {
				w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
				router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
				return
			}

			// Handlers must not modify the incoming *http.Request, so strip the
			// credentials from a copy before passing it downstream.
			r = r.Clone(r.Context())
			r.Header.Del("Authorization")
			r.Header.Del("x-api-key")
			next.ServeHTTP(w, r)
		})
	}
}

// extractAPIKey pulls a candidate API key from the request. A key in the
// Authorization header (the password of a Basic credential, or a Bearer
// token) takes precedence over the x-api-key header. It returns "" when no
// key is present.
func extractAPIKey(r *http.Request) string { var bearerKey, basicKey string if auth := r.Header.Get("Authorization"); auth != "" { if strings.HasPrefix(auth, "Bearer ") { bearerKey = strings.TrimPrefix(auth, "Bearer ") } else if strings.HasPrefix(auth, "Basic ") { encoded := strings.TrimPrefix(auth, "Basic ") if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil { if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 { basicKey = parts[1] // password field is the API key } } } } switch { case basicKey != "": return basicKey case bearerKey != "": return bearerKey default: return r.Header.Get("x-api-key") } } // CreateCORSMiddleware returns middleware that answers OPTIONS preflight // requests with permissive CORS headers (see issues #81, #77, #42). Non-OPTIONS // requests pass through untouched. func CreateCORSMiddleware() chain.Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodOptions { next.ServeHTTP(w, r) return } w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" { w.Header().Set("Access-Control-Allow-Headers", sanitizeAccessControlRequestHeaderValues(headers)) } else { w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept, X-Requested-With") } w.Header().Set("Access-Control-Max-Age", "86400") w.WriteHeader(http.StatusNoContent) }) } } func isTokenChar(r rune) bool { switch { case r >= 'a' && r <= 'z': case r >= 'A' && r <= 'Z': case r >= '0' && r <= '9': case strings.ContainsRune("!#$%&'*+-.^_`|~", r): default: return false } return true } // sanitizeAccessControlRequestHeaderValues drops any header names that contain // characters outside the HTTP token grammar before echoing them back. 
func sanitizeAccessControlRequestHeaderValues(headerValues string) string { parts := strings.Split(headerValues, ",") valid := make([]string, 0, len(parts)) for _, p := range parts { v := strings.TrimSpace(p) if v == "" { continue } validPart := true for _, c := range v { if !isTokenChar(c) { validPart = false break } } if validPart { valid = append(valid, v) } } return strings.Join(valid, ", ") }