internal/router,server,shared: refactor auth, libs (#839)
- refactor shared http functionality into internal/shared/http.go - remove stripping of Authorization and x-api-key - add Request Context middleware to internal/server - add /ui and /metrics behind auth middleware, fixes #717 Fix #717 Updates: #834
This commit is contained in:
@@ -0,0 +1,202 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type contextkey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
type ReqContextData struct {
|
||||
ApiKey string
|
||||
Model string
|
||||
ModelID string
|
||||
Streaming bool
|
||||
SendLoadingState bool
|
||||
}
|
||||
|
||||
var (
|
||||
ReqContextKey = &contextkey{"context"}
|
||||
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||
)
|
||||
|
||||
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
switch {
|
||||
case errors.Is(err, ErrNoModelInContext):
|
||||
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||
case errors.Is(err, ErrNoPeerModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||
case errors.Is(err, ErrNoLocalModelFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
case errors.Is(err, ErrNoRouterFound):
|
||||
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||
default:
|
||||
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
||||
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||
acceptHeader := r.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/plain") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(acceptHeader, "text/html") {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, html.EscapeString(message))))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
resp, err := json.Marshal(map[string]string{"src": "llama-swap", "error": message})
|
||||
if err != nil {
|
||||
w.Write([]byte(`{"src":"llama-swap", "error": "failed to marshal response"}`))
|
||||
return
|
||||
}
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
// FetchContext will attempt to get the model id from the context then
|
||||
// from the model body. If it extracts the model from the body it will
|
||||
// store the model in the context for downstream handlers. An error
|
||||
// will be returned when model can not be fetch from either location.
|
||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||
data, ok := ReadContext(r.Context())
|
||||
if ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if data, err := extractContext(r); err == nil && data.Model != "" {
|
||||
realName, _ := cfg.RealModelName(data.Model)
|
||||
if realName == "" {
|
||||
realName = data.Model
|
||||
}
|
||||
data.ModelID = realName
|
||||
if mc, ok := cfg.Models[realName]; ok {
|
||||
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||
}
|
||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return ReqContextData{}, ErrNoModelInContext
|
||||
}
|
||||
|
||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||
return context.WithValue(ctx, ReqContextKey, data)
|
||||
}
|
||||
|
||||
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
return data, ok
|
||||
}
|
||||
|
||||
// extractContext pulls fields from an HTTP request into a ReqContextData,
|
||||
// returning whatever is available. For GET requests it reads query parameters.
|
||||
// For POST requests it inspects Content-Type and parses JSON,
|
||||
// multipart/form-data, or application/x-www-form-urlencoded bodies. The
|
||||
// request body is always restored before returning. An error is returned only
|
||||
// for I/O or parse failures, not for missing fields.
|
||||
func extractContext(r *http.Request) (ReqContextData, error) {
|
||||
|
||||
apiKey := ExtractAPIKey(r)
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
q := r.URL.Query()
|
||||
return ReqContextData{
|
||||
Model: q.Get("model"),
|
||||
Streaming: q.Get("stream") == "true",
|
||||
ApiKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}()
|
||||
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return ReqContextData{
|
||||
Model: gjson.GetBytes(bodyBytes, "model").String(),
|
||||
Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||
ApiKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||
// buffered bytes. The deferred restore above will reset r.Body again
|
||||
// after parsing.
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return ReqContextData{
|
||||
Model: r.FormValue("model"),
|
||||
Streaming: r.FormValue("stream") == "true",
|
||||
ApiKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||
// then Bearer, then x-api-key.
|
||||
func ExtractAPIKey(r *http.Request) string {
|
||||
var bearerKey, basicKey string
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
scheme, credentials, ok := strings.Cut(auth, " ")
|
||||
if ok {
|
||||
switch strings.ToLower(scheme) {
|
||||
case "bearer":
|
||||
bearerKey = credentials
|
||||
case "basic":
|
||||
if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil {
|
||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||
basicKey = parts[1] // password field is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case basicKey != "":
|
||||
return basicKey
|
||||
case bearerKey != "":
|
||||
return bearerKey
|
||||
default:
|
||||
return r.Header.Get("x-api-key")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,426 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractContext_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "model=llama3", "llama3", false},
|
||||
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
||||
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
||||
{"model empty string", `{"model":""}`, "", false},
|
||||
{"model key missing", `{"stream":true}`, "", false},
|
||||
{"invalid json", `not-json`, "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_URLEncodedForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formModel string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
form := url.Values{}
|
||||
if tt.formModel != "" {
|
||||
form.Set("model", tt.formModel)
|
||||
}
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_MultipartForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formModel string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{"model present", "whisper-1", "whisper-1", false},
|
||||
{"model missing", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
if tt.formModel != "" {
|
||||
fw, _ := mw.CreateFormField("model")
|
||||
fw.Write([]byte(tt.formModel))
|
||||
}
|
||||
mw.Close()
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
got, err := extractContext(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||
}
|
||||
if got.Model != tt.wantModel {
|
||||
t.Errorf("want %q got %q", tt.wantModel, got.Model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
||||
body := `{"model":"llama3","stream":true}`
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if string(remaining) != body {
|
||||
t.Errorf("body not restored: want %q got %q", body, string(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
fw, _ := mw.CreateFormField("model")
|
||||
fw.Write([]byte("whisper-1"))
|
||||
ff, _ := mw.CreateFormFile("file", "audio.wav")
|
||||
ff.Write([]byte("fake-audio-bytes"))
|
||||
mw.Close()
|
||||
|
||||
original := buf.Bytes()
|
||||
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if !bytes.Equal(remaining, original) {
|
||||
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
||||
body := "model=whisper-1&extra=value"
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
if _, err := extractContext(r); err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
|
||||
remaining, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body after ExtractContext: %v", err)
|
||||
}
|
||||
if string(remaining) != body {
|
||||
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
if !ok {
|
||||
t.Fatalf("ContextKey not set or wrong type")
|
||||
}
|
||||
if data.Model != "llama3" {
|
||||
t.Errorf("want %q got %q", "llama3", data.Model)
|
||||
}
|
||||
if data.ModelID != "llama3" {
|
||||
t.Errorf("want %q got %q", "llama3", data.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext_WithAlias(t *testing.T) {
|
||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
||||
data, _ := ctx.Value(ReqContextKey).(ReqContextData)
|
||||
if data.Model != "llama" {
|
||||
t.Errorf("want requested %q got %q", "llama", data.Model)
|
||||
}
|
||||
if data.ModelID != "llama3" {
|
||||
t.Errorf("want real %q got %q", "llama3", data.ModelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
||||
parent := context.Background()
|
||||
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||
if v := parent.Value(ReqContextKey); v != nil {
|
||||
t.Errorf("parent context was mutated: %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
wantReq string
|
||||
wantReal string
|
||||
wantBool bool
|
||||
}{
|
||||
{
|
||||
name: "model present, same name",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
|
||||
wantReq: "llama3",
|
||||
wantReal: "llama3",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "model present, aliased",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
|
||||
wantReq: "llama",
|
||||
wantReal: "llama3",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "model absent",
|
||||
ctx: context.Background(),
|
||||
wantReq: "",
|
||||
wantReal: "",
|
||||
wantBool: false,
|
||||
},
|
||||
{
|
||||
name: "model is empty string",
|
||||
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
|
||||
wantReq: "",
|
||||
wantReal: "",
|
||||
wantBool: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotData, ok := ReadContext(tt.ctx)
|
||||
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
|
||||
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", "model=llama3&stream=true", true},
|
||||
{"streaming false", "model=llama3&stream=false", false},
|
||||
{"no stream param", "model=llama3", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantStreaming bool
|
||||
}{
|
||||
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||
{"no stream param", `{"model":"llama3"}`, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if got.Streaming != tt.wantStreaming {
|
||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractContext: %v", err)
|
||||
}
|
||||
if !got.Streaming {
|
||||
t.Error("Streaming should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContext_ApiKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
method string
|
||||
ct string
|
||||
body string
|
||||
auth string
|
||||
xapi string
|
||||
wantKey string
|
||||
}{
|
||||
{"GET bearer", http.MethodGet, "", "", "Bearer sk-get", "", "sk-get"},
|
||||
{"GET x-api-key", http.MethodGet, "", "", "", "xk-get", "xk-get"},
|
||||
{"GET basic", http.MethodGet, "", "", basicHeader("u", "pw-get"), "", "pw-get"},
|
||||
{"JSON bearer", http.MethodPost, "application/json", `{"model":"m"}`, "Bearer sk-json", "", "sk-json"},
|
||||
{"JSON x-api-key", http.MethodPost, "application/json", `{"model":"m"}`, "", "xk-json", "xk-json"},
|
||||
{"form bearer", http.MethodPost, "application/x-www-form-urlencoded", "model=m", "Bearer sk-form", "", "sk-form"},
|
||||
{"no key", http.MethodGet, "", "", "", "", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
var body io.Reader
|
||||
if c.body != "" {
|
||||
body = strings.NewReader(c.body)
|
||||
}
|
||||
r, _ := http.NewRequest(c.method, "/", body)
|
||||
if c.ct != "" {
|
||||
r.Header.Set("Content-Type", c.ct)
|
||||
}
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
got, err := extractContext(r)
|
||||
if err != nil {
|
||||
t.Fatalf("extractContext: %v", err)
|
||||
}
|
||||
if got.ApiKey != c.wantKey {
|
||||
t.Errorf("ApiKey = %q, want %q", got.ApiKey, c.wantKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
auth string
|
||||
xapi string
|
||||
want string
|
||||
}{
|
||||
{"none", "", "", ""},
|
||||
{"bearer", "Bearer tok123", "", "tok123"},
|
||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||
{"x-api-key", "", "xkey", "xkey"},
|
||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||
{"lowercase bearer", "bearer tok123", "", "tok123"},
|
||||
{"lowercase basic", "basic " + base64.StdEncoding.EncodeToString([]byte("user:pw-key")), "", "pw-key"},
|
||||
{"mixed case BEARER", "BEARER tok456", "", "tok456"},
|
||||
{"mixed case bAsIc", "bAsIc " + base64.StdEncoding.EncodeToString([]byte("u:bk")), "", "bk"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
if got := ExtractAPIKey(r); got != c.want {
|
||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user