internal/router,server,shared: refactor auth, libs (#839)
- refactor shared http functionality into internal/shared/http.go - remove stripping of Authorization and x-api-key - add Request Context middleware to internal/server - add /ui and /metrics behind auth middleware, fixes #717 Fix #717 Updates: #834
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/process"
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
type shutdownReq struct {
|
type shutdownReq struct {
|
||||||
@@ -399,13 +400,13 @@ func (b *baseRouter) Shutdown(timeout time.Duration) error {
|
|||||||
|
|
||||||
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
if b.shuttingDown.Load() {
|
if b.shuttingDown.Load() {
|
||||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := FetchContext(req, b.config)
|
data, err := shared.FetchContext(req, b.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SendError(w, req, err)
|
shared.SendError(w, req, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,7 +425,7 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
case <-req.Context().Done():
|
case <-req.Context().Done():
|
||||||
return
|
return
|
||||||
case <-b.shutdownCtx.Done():
|
case <-b.shutdownCtx.Done():
|
||||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,12 +476,12 @@ func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
return
|
return
|
||||||
case <-b.shutdownCtx.Done():
|
case <-b.shutdownCtx.Done():
|
||||||
finishLoading()
|
finishLoading()
|
||||||
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Err != nil {
|
if resp.Err != nil {
|
||||||
SendError(w, req, resp.Err)
|
shared.SendError(w, req, resp.Err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.HandleFunc(w, req)
|
resp.HandleFunc(w, req)
|
||||||
|
|||||||
@@ -226,69 +226,6 @@ func TestIsLoadingPath(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractContext_Streaming_GET(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
query string
|
|
||||||
wantStreaming bool
|
|
||||||
}{
|
|
||||||
{"streaming true", "model=llama3&stream=true", true},
|
|
||||||
{"streaming false", "model=llama3&stream=false", false},
|
|
||||||
{"no stream param", "model=llama3", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if got.Streaming != tt.wantStreaming {
|
|
||||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
body string
|
|
||||||
wantStreaming bool
|
|
||||||
}{
|
|
||||||
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
|
||||||
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
|
||||||
{"no stream param", `{"model":"llama3"}`, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
|
||||||
r.Header.Set("Content-Type", "application/json")
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if got.Streaming != tt.wantStreaming {
|
|
||||||
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
|
||||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
got, err := ExtractContext(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
|
||||||
}
|
|
||||||
if !got.Streaming {
|
|
||||||
t.Error("Streaming should be true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func countSSEMessages(s string) int {
|
func countSSEMessages(s string) int {
|
||||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||||
count := 0
|
count := 0
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
type peerMember struct {
|
type peerMember struct {
|
||||||
@@ -146,22 +147,22 @@ func (r *Peer) Shutdown(timeout time.Duration) error {
|
|||||||
|
|
||||||
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
if r.shuttingDown.Load() {
|
if r.shuttingDown.Load() {
|
||||||
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.inflight.Add(1)
|
r.inflight.Add(1)
|
||||||
defer r.inflight.Done()
|
defer r.inflight.Done()
|
||||||
|
|
||||||
data, err := FetchContext(req, r.cfg)
|
data, err := shared.FetchContext(req, r.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SendError(w, req, err)
|
shared.SendError(w, req, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pp, found := r.peers[data.ModelID]
|
pp, found := r.peers[data.ModelID]
|
||||||
if !found {
|
if !found {
|
||||||
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
||||||
SendError(w, req, ErrNoPeerModelFound)
|
shared.SendError(w, req, ErrNoPeerModelFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testLogger = logmon.NewWriter(os.Stdout)
|
var testLogger = logmon.NewWriter(os.Stdout)
|
||||||
@@ -142,7 +143,7 @@ func TestPeer_ServeHTTP_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -178,7 +179,7 @@ func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -212,7 +213,7 @@ func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -246,7 +247,7 @@ func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -279,7 +280,7 @@ func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -311,7 +312,7 @@ func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -347,7 +348,7 @@ func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
@@ -385,7 +386,7 @@ func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -448,7 +449,7 @@ func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -551,7 +552,7 @@ func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
|
|||||||
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "context-model", ModelID: "context-model"}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
pr.ServeHTTP(w, req)
|
pr.ServeHTTP(w, req)
|
||||||
|
|||||||
+4
-152
@@ -1,40 +1,18 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/process"
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextkey struct {
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ReqContextData struct {
|
|
||||||
Model string
|
|
||||||
ModelID string
|
|
||||||
Streaming bool
|
|
||||||
SendLoadingState bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
ErrNoRouterFound = shared.ErrNoRouterFound
|
||||||
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
ErrNoPeerModelFound = shared.ErrNoPeerModelFound
|
||||||
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
ErrNoLocalModelFound = shared.ErrNoLocalModelFound
|
||||||
ErrNoLocalModelFound = scheduler.ErrModelNotFound
|
|
||||||
|
|
||||||
ContextKey = &contextkey{"context"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
@@ -72,129 +50,3 @@ type LocalRouter interface {
|
|||||||
// model is not known to this router.
|
// model is not known to this router.
|
||||||
ProcessLogger(modelID string) (*logmon.Monitor, bool)
|
ProcessLogger(modelID string) (*logmon.Monitor, bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchContext will attempt to get the model id from the context then
|
|
||||||
// from the model body. If it extracts the model from the body it will
|
|
||||||
// store the model in the context for downstream handlers. An error
|
|
||||||
// will be returned when model can not be fetch from either location.
|
|
||||||
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
|
||||||
data, ok := ReadContext(r.Context())
|
|
||||||
if ok {
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if data, err := ExtractContext(r); err == nil {
|
|
||||||
realName, _ := cfg.RealModelName(data.Model)
|
|
||||||
if realName == "" {
|
|
||||||
realName = data.Model
|
|
||||||
}
|
|
||||||
data.ModelID = realName
|
|
||||||
if mc, ok := cfg.Models[realName]; ok {
|
|
||||||
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
|
||||||
}
|
|
||||||
*r = *r.WithContext(SetContext(r.Context(), data))
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return ReqContextData{}, ErrNoModelInContext
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
|
||||||
return context.WithValue(ctx, ContextKey, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
|
||||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
|
||||||
return data, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtractContext pulls the model name from an HTTP request without consuming the
|
|
||||||
// body. For GET requests it reads the "model" query parameter. For POST
|
|
||||||
// requests it inspects Content-Type and parses JSON, multipart/form-data, or
|
|
||||||
// application/x-www-form-urlencoded bodies. The request body is always restored
|
|
||||||
// before returning so downstream handlers — including reverse proxies that
|
|
||||||
// forward raw bytes upstream — can still read it.
|
|
||||||
func ExtractContext(r *http.Request) (ReqContextData, error) {
|
|
||||||
if r.Method == http.MethodGet {
|
|
||||||
if model := r.URL.Query().Get("model"); model != "" {
|
|
||||||
return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil
|
|
||||||
}
|
|
||||||
return ReqContextData{}, fmt.Errorf("missing 'model' query parameter")
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
}()
|
|
||||||
|
|
||||||
contentType := r.Header.Get("Content-Type")
|
|
||||||
|
|
||||||
if strings.Contains(contentType, "application/json") {
|
|
||||||
model := gjson.GetBytes(bodyBytes, "model").String()
|
|
||||||
if model == "" {
|
|
||||||
return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body")
|
|
||||||
}
|
|
||||||
return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Form parsers read from r.Body, so feed them a fresh reader over the
|
|
||||||
// buffered bytes. The deferred restore above will reset r.Body again
|
|
||||||
// after parsing.
|
|
||||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
if strings.Contains(contentType, "multipart/form-data") {
|
|
||||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
|
||||||
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if model := r.FormValue("model"); model != "" {
|
|
||||||
return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return ReqContextData{}, fmt.Errorf("missing 'model' parameter")
|
|
||||||
}
|
|
||||||
|
|
||||||
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
switch {
|
|
||||||
case errors.Is(err, ErrNoModelInContext):
|
|
||||||
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
|
||||||
case errors.Is(err, ErrNoPeerModelFound):
|
|
||||||
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
|
||||||
case errors.Is(err, ErrNoLocalModelFound):
|
|
||||||
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
|
||||||
case errors.Is(err, ErrNoRouterFound):
|
|
||||||
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
|
||||||
default:
|
|
||||||
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
|
||||||
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
|
||||||
// Check Accept header for preferred response format
|
|
||||||
acceptHeader := r.Header.Get("Accept")
|
|
||||||
if strings.Contains(acceptHeader, "text/plain") {
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(acceptHeader, "text/html") {
|
|
||||||
w.Header().Set("Content-Type", "text/html")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, message)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message)))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,17 +11,17 @@ package scheduler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/process"
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrModelNotFound is granted to callers whose model is not handled by this
|
// ErrModelNotFound is granted to callers whose model is not handled by this
|
||||||
// router. The router package aliases it so SendError can match it.
|
// router. It is an alias for shared.ErrNoLocalModelFound.
|
||||||
var ErrModelNotFound = fmt.Errorf("local model not found")
|
var ErrModelNotFound = shared.ErrNoLocalModelFound
|
||||||
|
|
||||||
// Swapper is the eviction policy: it decides which running models must be
|
// Swapper is the eviction policy: it decides which running models must be
|
||||||
// stopped before a target can serve. It is orthogonal to the scheduling
|
// stopped before a target can serve. It is orthogonal to the scheduling
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/event"
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -163,7 +162,7 @@ func (s *Server) startPreload() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
|
req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID}))
|
||||||
|
|
||||||
dw := &discardResponseWriter{status: http.StatusOK}
|
dw := &discardResponseWriter{status: http.StatusOK}
|
||||||
s.local.ServeHTTP(dw, req)
|
s.local.ServeHTTP(dw, req)
|
||||||
@@ -208,7 +207,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
||||||
if !found {
|
if !found {
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +229,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Strip the /upstream/<model> prefix before forwarding.
|
// Strip the /upstream/<model> prefix before forwarding.
|
||||||
r.URL.Path = remainingPath
|
r.URL.Path = remainingPath
|
||||||
// Pin the resolved model so the router skips body/query extraction.
|
// Pin the resolved model so the router skips body/query extraction.
|
||||||
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
|
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID}))
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case s.local.Handles(modelID):
|
case s.local.Handles(modelID):
|
||||||
@@ -238,7 +237,7 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
case s.peer.Handles(modelID):
|
case s.peer.Handles(modelID):
|
||||||
s.peer.ServeHTTP(w, r)
|
s.peer.ServeHTTP(w, r)
|
||||||
default:
|
default:
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/event"
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -76,11 +75,11 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
|||||||
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
||||||
realName, found := s.cfg.RealModelName(requested)
|
realName, found := s.cfg.RealModelName(requested)
|
||||||
if !found {
|
if !found {
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
shared.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !s.local.Handles(realName) {
|
if !s.local.Handles(realName) {
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
shared.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.local.Unload(apiUnloadTimeout, realName)
|
s.local.Unload(apiUnloadTimeout, realName)
|
||||||
@@ -92,7 +91,7 @@ func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||||
data, err := s.metrics.getMetricsJSON()
|
data, err := s.metrics.getMetricsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
shared.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -103,7 +102,7 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
|||||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||||
if s.perf == nil {
|
if s.perf == nil {
|
||||||
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
shared.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,7 +111,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
|||||||
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
||||||
after, err := time.Parse(time.RFC3339, afterStr)
|
after, err := time.Parse(time.RFC3339, afterStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
shared.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
||||||
@@ -153,19 +152,19 @@ func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
||||||
id, err := strconv.Atoi(r.PathValue("id"))
|
id, err := strconv.Atoi(r.PathValue("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
shared.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
capture := s.metrics.getCaptureByID(id)
|
capture := s.metrics.getCaptureByID(id)
|
||||||
if capture == nil {
|
if capture == nil {
|
||||||
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
shared.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(capture)
|
jsonBytes, err := json.Marshal(capture)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
shared.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -198,7 +197,7 @@ func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+17
-31
@@ -1,19 +1,17 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateAuthMiddleware returns middleware that validates API keys when the
|
// CreateAuthMiddleware returns middleware that validates API keys when the
|
||||||
// config declares any. It accepts the key via Authorization: Bearer,
|
// config declares any. It accepts the key via Authorization: Bearer,
|
||||||
// Authorization: Basic (password field), or x-api-key. On success the auth
|
// Authorization: Basic (password field), or x-api-key. When no keys are
|
||||||
// headers are stripped so they never leak to upstream. When no keys are
|
|
||||||
// configured the middleware is a pass-through.
|
// configured the middleware is a pass-through.
|
||||||
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||||
keys := cfg.RequiredAPIKeys
|
keys := cfg.RequiredAPIKeys
|
||||||
@@ -22,7 +20,7 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
provided := extractAPIKey(r)
|
provided := shared.ExtractAPIKey(r)
|
||||||
|
|
||||||
valid := false
|
valid := false
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
@@ -33,41 +31,29 @@ func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
}
|
}
|
||||||
if !valid {
|
if !valid {
|
||||||
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||||
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
shared.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Header.Del("Authorization")
|
|
||||||
r.Header.Del("x-api-key")
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
// CreateRequestContextMiddleware returns middleware that extracts model and
|
||||||
// then Bearer, then x-api-key.
|
// auth info from the request into the context. Requests where no model can be
|
||||||
func extractAPIKey(r *http.Request) string {
|
// identified are rejected with a 404.
|
||||||
var bearerKey, basicKey string
|
func CreateRequestContextMiddleware(cfg config.Config) chain.Middleware {
|
||||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
return func(next http.Handler) http.Handler {
|
||||||
if strings.HasPrefix(auth, "Bearer ") {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
data, err := shared.FetchContext(r, cfg)
|
||||||
} else if strings.HasPrefix(auth, "Basic ") {
|
if err != nil {
|
||||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
return
|
||||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
|
||||||
basicKey = parts[1] // password field is the API key
|
|
||||||
}
|
}
|
||||||
}
|
_ = data
|
||||||
}
|
next.ServeHTTP(w, r)
|
||||||
}
|
})
|
||||||
|
|
||||||
switch {
|
|
||||||
case basicKey != "":
|
|
||||||
return basicKey
|
|
||||||
case bearerKey != "":
|
|
||||||
return bearerKey
|
|
||||||
default:
|
|
||||||
return r.Header.Get("x-api-key")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,48 +1,14 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
|
||||||
basicHeader := func(user, pass string) string {
|
|
||||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
|
||||||
}
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
auth string
|
|
||||||
xapi string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"none", "", "", ""},
|
|
||||||
{"bearer", "Bearer tok123", "", "tok123"},
|
|
||||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
|
||||||
{"x-api-key", "", "xkey", "xkey"},
|
|
||||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
|
||||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
|
||||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
|
||||||
}
|
|
||||||
for _, c := range cases {
|
|
||||||
t.Run(c.name, func(t *testing.T) {
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
if c.auth != "" {
|
|
||||||
r.Header.Set("Authorization", c.auth)
|
|
||||||
}
|
|
||||||
if c.xapi != "" {
|
|
||||||
r.Header.Set("x-api-key", c.xapi)
|
|
||||||
}
|
|
||||||
if got := extractAPIKey(r); got != c.want {
|
|
||||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
in string
|
in string
|
||||||
@@ -74,11 +40,42 @@ func TestServer_IsTokenChar(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_RequestContextMiddleware(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"llama3": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := CreateRequestContextMiddleware(cfg)
|
||||||
|
|
||||||
|
t.Run("known model passes through", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"llama3"}`))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
mw(final).ServeHTTP(w, r)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing model returns 404", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{}`))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
mw(final).ServeHTTP(w, r)
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status = %d, want 404", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_AuthMiddleware(t *testing.T) {
|
func TestServer_AuthMiddleware(t *testing.T) {
|
||||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
|
|
||||||
t.Error("auth headers leaked to upstream")
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
||||||
@@ -32,9 +32,9 @@ func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
data, err := router.FetchContext(r, cfg)
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
func concurrencyTestReq(model string) *http.Request {
|
func concurrencyTestReq(model string) *http.Request {
|
||||||
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
|
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
|
||||||
return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model}))
|
return r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: model, ModelID: model}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
|
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,9 +34,9 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := router.FetchContext(r, cfg)
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,13 +48,13 @@ func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
|
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
shared.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err = applyFilters(body, data.Model, useModelName, filters)
|
body, err = applyFilters(body, data.Model, useModelName, filters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
shared.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,9 +84,9 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := router.FetchContext(r, cfg)
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,13 +97,13 @@ func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
shared.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
shared.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
||||||
@@ -102,13 +102,13 @@ func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
logger, err := s.getLogger(logMonitorID)
|
logger, err := s.getLogger(logMonitorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
shared.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
shared.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateMetricsMiddleware returns middleware that records token metrics for
|
// CreateMetricsMiddleware returns middleware that records token metrics for
|
||||||
@@ -23,9 +23,9 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
|
|||||||
|
|
||||||
// Resolve the model now so downstream dispatch hits the context
|
// Resolve the model now so downstream dispatch hits the context
|
||||||
// fast path; FetchContext restores the request body.
|
// fast path; FetchContext restores the request body.
|
||||||
data, err := router.FetchContext(r, cfg)
|
data, err := shared.FetchContext(r, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+12
-18
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||||
"github.com/mostlygeek/llama-swap/internal/router"
|
"github.com/mostlygeek/llama-swap/internal/router"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
||||||
@@ -138,13 +139,13 @@ func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, up
|
|||||||
}
|
}
|
||||||
|
|
||||||
// localPeerHandler dispatches a model-routed request to the local or peer
|
// localPeerHandler dispatches a model-routed request to the local or peer
|
||||||
// router. The model is resolved once via router.FetchContext.
|
// router. The model is resolved once via shared.FetchContext.
|
||||||
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
stripVersionPrefix(r)
|
stripVersionPrefix(r)
|
||||||
|
|
||||||
data, err := router.FetchContext(r, s.cfg)
|
data, err := shared.FetchContext(r, s.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
router.SendError(w, r, router.ErrNoModelInContext)
|
shared.SendError(w, r, shared.ErrNoModelInContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,7 +157,7 @@ func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
||||||
s.peer.ServeHTTP(w, r)
|
s.peer.ServeHTTP(w, r)
|
||||||
default:
|
default:
|
||||||
router.SendError(w, r, router.ErrNoRouterFound)
|
shared.SendError(w, r, router.ErrNoRouterFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,21 +172,14 @@ func stripVersionPrefix(r *http.Request) {
|
|||||||
// routes builds the mux, registers every route, and wraps the mux with the
|
// routes builds the mux, registers every route, and wraps the mux with the
|
||||||
// global CORS middleware.
|
// global CORS middleware.
|
||||||
func (s *Server) routes() {
|
func (s *Server) routes() {
|
||||||
authMW := CreateAuthMiddleware(s.cfg)
|
|
||||||
filterMW := CreateFilterMiddleware(s.cfg)
|
|
||||||
formFilterMW := CreateFormFilterMiddleware(s.cfg)
|
|
||||||
|
|
||||||
// Model-dispatched routes get auth + per-model concurrency limiting + body
|
authMW := CreateAuthMiddleware(s.cfg)
|
||||||
// filters + in-flight tracking + token metrics. concurrencyMW rejects with
|
|
||||||
// 429 before the body filters do any rewrite work. filterMW rewrites JSON
|
|
||||||
// bodies and formFilterMW rewrites multipart bodies; each is a no-op for the
|
|
||||||
// other's Content-Type. Both run before the metrics middleware so it buffers
|
|
||||||
// the rewritten body.
|
|
||||||
modelChain := chain.New(
|
modelChain := chain.New(
|
||||||
authMW,
|
authMW,
|
||||||
|
CreateRequestContextMiddleware(s.cfg),
|
||||||
CreateConcurrencyMiddleware(s.cfg),
|
CreateConcurrencyMiddleware(s.cfg),
|
||||||
filterMW,
|
CreateFilterMiddleware(s.cfg),
|
||||||
formFilterMW,
|
CreateFormFilterMiddleware(s.cfg),
|
||||||
CreateInflightMiddleware(s.inflight),
|
CreateInflightMiddleware(s.inflight),
|
||||||
CreateMetricsMiddleware(s.metrics, s.cfg),
|
CreateMetricsMiddleware(s.metrics, s.cfg),
|
||||||
)
|
)
|
||||||
@@ -216,11 +210,11 @@ func (s *Server) routes() {
|
|||||||
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
||||||
|
|
||||||
// Embedded UI.
|
// Embedded UI.
|
||||||
mux.HandleFunc("GET /ui/", s.handleUI)
|
mux.Handle("GET /ui/", chain.New(authMW).ThenFunc(s.handleUI))
|
||||||
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
||||||
|
|
||||||
// Prometheus metrics (no auth, matches the legacy endpoint).
|
// Prometheus metrics (wrapped by apiChain, matches the legacy endpoint).
|
||||||
mux.HandleFunc("GET /metrics", s.handleMetrics)
|
mux.Handle("GET /metrics", apiChain.ThenFunc(s.handleMetrics))
|
||||||
|
|
||||||
// Operations endpoints.
|
// Operations endpoints.
|
||||||
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||||
|
|||||||
@@ -0,0 +1,202 @@
|
|||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type contextkey struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReqContextData struct {
|
||||||
|
ApiKey string
|
||||||
|
Model string
|
||||||
|
ModelID string
|
||||||
|
Streaming bool
|
||||||
|
SendLoadingState bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ReqContextKey = &contextkey{"context"}
|
||||||
|
ErrNoModelInContext = fmt.Errorf("no model in request context")
|
||||||
|
ErrNoRouterFound = fmt.Errorf("no router found for model")
|
||||||
|
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
|
||||||
|
ErrNoLocalModelFound = fmt.Errorf("local model not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
func SendError(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, ErrNoModelInContext):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
|
||||||
|
case errors.Is(err, ErrNoPeerModelFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
|
||||||
|
case errors.Is(err, ErrNoLocalModelFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||||
|
case errors.Is(err, ErrNoRouterFound):
|
||||||
|
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
|
||||||
|
default:
|
||||||
|
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendResponse detects what content type the client prefers and returns an error response in that format.
|
||||||
|
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||||
|
acceptHeader := r.Header.Get("Accept")
|
||||||
|
if strings.Contains(acceptHeader, "text/plain") {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(acceptHeader, "text/html") {
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, html.EscapeString(message))))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
resp, err := json.Marshal(map[string]string{"src": "llama-swap", "error": message})
|
||||||
|
if err != nil {
|
||||||
|
w.Write([]byte(`{"src":"llama-swap", "error": "failed to marshal response"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchContext will attempt to get the model id from the context then
|
||||||
|
// from the model body. If it extracts the model from the body it will
|
||||||
|
// store the model in the context for downstream handlers. An error
|
||||||
|
// will be returned when model can not be fetch from either location.
|
||||||
|
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
|
||||||
|
data, ok := ReadContext(r.Context())
|
||||||
|
if ok {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if data, err := extractContext(r); err == nil && data.Model != "" {
|
||||||
|
realName, _ := cfg.RealModelName(data.Model)
|
||||||
|
if realName == "" {
|
||||||
|
realName = data.Model
|
||||||
|
}
|
||||||
|
data.ModelID = realName
|
||||||
|
if mc, ok := cfg.Models[realName]; ok {
|
||||||
|
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
|
||||||
|
}
|
||||||
|
*r = *r.WithContext(SetContext(r.Context(), data))
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReqContextData{}, ErrNoModelInContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetContext(ctx context.Context, data ReqContextData) context.Context {
|
||||||
|
return context.WithValue(ctx, ReqContextKey, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadContext(ctx context.Context) (ReqContextData, bool) {
|
||||||
|
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||||
|
return data, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractContext pulls fields from an HTTP request into a ReqContextData,
|
||||||
|
// returning whatever is available. For GET requests it reads query parameters.
|
||||||
|
// For POST requests it inspects Content-Type and parses JSON,
|
||||||
|
// multipart/form-data, or application/x-www-form-urlencoded bodies. The
|
||||||
|
// request body is always restored before returning. An error is returned only
|
||||||
|
// for I/O or parse failures, not for missing fields.
|
||||||
|
func extractContext(r *http.Request) (ReqContextData, error) {
|
||||||
|
|
||||||
|
apiKey := ExtractAPIKey(r)
|
||||||
|
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
|
q := r.URL.Query()
|
||||||
|
return ReqContextData{
|
||||||
|
Model: q.Get("model"),
|
||||||
|
Streaming: q.Get("stream") == "true",
|
||||||
|
ApiKey: apiKey,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
}()
|
||||||
|
|
||||||
|
contentType := r.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
if strings.Contains(contentType, "application/json") {
|
||||||
|
return ReqContextData{
|
||||||
|
Model: gjson.GetBytes(bodyBytes, "model").String(),
|
||||||
|
Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||||
|
ApiKey: apiKey,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Form parsers read from r.Body, so feed them a fresh reader over the
|
||||||
|
// buffered bytes. The deferred restore above will reset r.Body again
|
||||||
|
// after parsing.
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
if strings.Contains(contentType, "multipart/form-data") {
|
||||||
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReqContextData{
|
||||||
|
Model: r.FormValue("model"),
|
||||||
|
Streaming: r.FormValue("stream") == "true",
|
||||||
|
ApiKey: apiKey,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||||
|
// then Bearer, then x-api-key.
|
||||||
|
func ExtractAPIKey(r *http.Request) string {
|
||||||
|
var bearerKey, basicKey string
|
||||||
|
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||||
|
scheme, credentials, ok := strings.Cut(auth, " ")
|
||||||
|
if ok {
|
||||||
|
switch strings.ToLower(scheme) {
|
||||||
|
case "bearer":
|
||||||
|
bearerKey = credentials
|
||||||
|
case "basic":
|
||||||
|
if decoded, err := base64.StdEncoding.DecodeString(credentials); err == nil {
|
||||||
|
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||||
|
basicKey = parts[1] // password field is the API key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case basicKey != "":
|
||||||
|
return basicKey
|
||||||
|
case bearerKey != "":
|
||||||
|
return bearerKey
|
||||||
|
default:
|
||||||
|
return r.Header.Get("x-api-key")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
package router
|
package shared
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -20,13 +22,13 @@ func TestExtractContext_GET(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"model present", "model=llama3", "llama3", false},
|
{"model present", "model=llama3", "llama3", false},
|
||||||
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
|
||||||
{"model missing", "", "", true},
|
{"model missing", "", "", false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||||
got, err := ExtractContext(r)
|
got, err := extractContext(r)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
}
|
}
|
||||||
@@ -46,16 +48,16 @@ func TestExtractContext_JSON(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
|
||||||
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
|
||||||
{"model empty string", `{"model":""}`, "", true},
|
{"model empty string", `{"model":""}`, "", false},
|
||||||
{"model key missing", `{"stream":true}`, "", true},
|
{"model key missing", `{"stream":true}`, "", false},
|
||||||
{"invalid json", `not-json`, "", true},
|
{"invalid json", `not-json`, "", false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||||
r.Header.Set("Content-Type", "application/json")
|
r.Header.Set("Content-Type", "application/json")
|
||||||
got, err := ExtractContext(r)
|
got, err := extractContext(r)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
}
|
}
|
||||||
@@ -74,7 +76,7 @@ func TestExtractContext_URLEncodedForm(t *testing.T) {
|
|||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"model present", "whisper-1", "whisper-1", false},
|
{"model present", "whisper-1", "whisper-1", false},
|
||||||
{"model missing", "", "", true},
|
{"model missing", "", "", false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -85,7 +87,7 @@ func TestExtractContext_URLEncodedForm(t *testing.T) {
|
|||||||
}
|
}
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
|
||||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
got, err := ExtractContext(r)
|
got, err := extractContext(r)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
}
|
}
|
||||||
@@ -104,7 +106,7 @@ func TestExtractContext_MultipartForm(t *testing.T) {
|
|||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"model present", "whisper-1", "whisper-1", false},
|
{"model present", "whisper-1", "whisper-1", false},
|
||||||
{"model missing", "", "", true},
|
{"model missing", "", "", false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -119,7 +121,7 @@ func TestExtractContext_MultipartForm(t *testing.T) {
|
|||||||
|
|
||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
got, err := ExtractContext(r)
|
got, err := extractContext(r)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
|
||||||
}
|
}
|
||||||
@@ -135,7 +137,7 @@ func TestExtractContext_JSONBodyRestored(t *testing.T) {
|
|||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
|
||||||
r.Header.Set("Content-Type", "application/json")
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
if _, err := ExtractContext(r); err != nil {
|
if _, err := extractContext(r); err != nil {
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,7 +164,7 @@ func TestExtractContext_MultipartBodyRestored(t *testing.T) {
|
|||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
|
||||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
|
||||||
if _, err := ExtractContext(r); err != nil {
|
if _, err := extractContext(r); err != nil {
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,7 +182,7 @@ func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
|||||||
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
|
||||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
if _, err := ExtractContext(r); err != nil {
|
if _, err := extractContext(r); err != nil {
|
||||||
t.Fatalf("ExtractContext: %v", err)
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,7 +197,7 @@ func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
|
|||||||
|
|
||||||
func TestSetContext(t *testing.T) {
|
func TestSetContext(t *testing.T) {
|
||||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||||
data, ok := ctx.Value(ContextKey).(ReqContextData)
|
data, ok := ctx.Value(ReqContextKey).(ReqContextData)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("ContextKey not set or wrong type")
|
t.Fatalf("ContextKey not set or wrong type")
|
||||||
}
|
}
|
||||||
@@ -209,7 +211,7 @@ func TestSetContext(t *testing.T) {
|
|||||||
|
|
||||||
func TestSetContext_WithAlias(t *testing.T) {
|
func TestSetContext_WithAlias(t *testing.T) {
|
||||||
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
|
||||||
data, _ := ctx.Value(ContextKey).(ReqContextData)
|
data, _ := ctx.Value(ReqContextKey).(ReqContextData)
|
||||||
if data.Model != "llama" {
|
if data.Model != "llama" {
|
||||||
t.Errorf("want requested %q got %q", "llama", data.Model)
|
t.Errorf("want requested %q got %q", "llama", data.Model)
|
||||||
}
|
}
|
||||||
@@ -221,7 +223,7 @@ func TestSetContext_WithAlias(t *testing.T) {
|
|||||||
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
func TestSetContext_DoesNotMutateParent(t *testing.T) {
|
||||||
parent := context.Background()
|
parent := context.Background()
|
||||||
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
|
||||||
if v := parent.Value(ContextKey); v != nil {
|
if v := parent.Value(ReqContextKey); v != nil {
|
||||||
t.Errorf("parent context was mutated: %v", v)
|
t.Errorf("parent context was mutated: %v", v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -273,3 +275,152 @@ func TestReadContext(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_GET(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
wantStreaming bool
|
||||||
|
}{
|
||||||
|
{"streaming true", "model=llama3&stream=true", true},
|
||||||
|
{"streaming false", "model=llama3&stream=false", false},
|
||||||
|
{"no stream param", "model=llama3", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if got.Streaming != tt.wantStreaming {
|
||||||
|
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_JSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantStreaming bool
|
||||||
|
}{
|
||||||
|
{"streaming true", `{"model":"llama3","stream":true}`, true},
|
||||||
|
{"streaming false", `{"model":"llama3","stream":false}`, false},
|
||||||
|
{"no stream param", `{"model":"llama3"}`, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if got.Streaming != tt.wantStreaming {
|
||||||
|
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractContext: %v", err)
|
||||||
|
}
|
||||||
|
if !got.Streaming {
|
||||||
|
t.Error("Streaming should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractContext_ApiKey(t *testing.T) {
|
||||||
|
basicHeader := func(user, pass string) string {
|
||||||
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||||
|
}
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
ct string
|
||||||
|
body string
|
||||||
|
auth string
|
||||||
|
xapi string
|
||||||
|
wantKey string
|
||||||
|
}{
|
||||||
|
{"GET bearer", http.MethodGet, "", "", "Bearer sk-get", "", "sk-get"},
|
||||||
|
{"GET x-api-key", http.MethodGet, "", "", "", "xk-get", "xk-get"},
|
||||||
|
{"GET basic", http.MethodGet, "", "", basicHeader("u", "pw-get"), "", "pw-get"},
|
||||||
|
{"JSON bearer", http.MethodPost, "application/json", `{"model":"m"}`, "Bearer sk-json", "", "sk-json"},
|
||||||
|
{"JSON x-api-key", http.MethodPost, "application/json", `{"model":"m"}`, "", "xk-json", "xk-json"},
|
||||||
|
{"form bearer", http.MethodPost, "application/x-www-form-urlencoded", "model=m", "Bearer sk-form", "", "sk-form"},
|
||||||
|
{"no key", http.MethodGet, "", "", "", "", ""},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
var body io.Reader
|
||||||
|
if c.body != "" {
|
||||||
|
body = strings.NewReader(c.body)
|
||||||
|
}
|
||||||
|
r, _ := http.NewRequest(c.method, "/", body)
|
||||||
|
if c.ct != "" {
|
||||||
|
r.Header.Set("Content-Type", c.ct)
|
||||||
|
}
|
||||||
|
if c.auth != "" {
|
||||||
|
r.Header.Set("Authorization", c.auth)
|
||||||
|
}
|
||||||
|
if c.xapi != "" {
|
||||||
|
r.Header.Set("x-api-key", c.xapi)
|
||||||
|
}
|
||||||
|
got, err := extractContext(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("extractContext: %v", err)
|
||||||
|
}
|
||||||
|
if got.ApiKey != c.wantKey {
|
||||||
|
t.Errorf("ApiKey = %q, want %q", got.ApiKey, c.wantKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||||
|
basicHeader := func(user, pass string) string {
|
||||||
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||||
|
}
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
auth string
|
||||||
|
xapi string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"none", "", "", ""},
|
||||||
|
{"bearer", "Bearer tok123", "", "tok123"},
|
||||||
|
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||||
|
{"x-api-key", "", "xkey", "xkey"},
|
||||||
|
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||||
|
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||||
|
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||||
|
{"lowercase bearer", "bearer tok123", "", "tok123"},
|
||||||
|
{"lowercase basic", "basic " + base64.StdEncoding.EncodeToString([]byte("user:pw-key")), "", "pw-key"},
|
||||||
|
{"mixed case BEARER", "BEARER tok456", "", "tok456"},
|
||||||
|
{"mixed case bAsIc", "bAsIc " + base64.StdEncoding.EncodeToString([]byte("u:bk")), "", "bk"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if c.auth != "" {
|
||||||
|
r.Header.Set("Authorization", c.auth)
|
||||||
|
}
|
||||||
|
if c.xapi != "" {
|
||||||
|
r.Header.Set("x-api-key", c.xapi)
|
||||||
|
}
|
||||||
|
if got := ExtractAPIKey(r); got != c.want {
|
||||||
|
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user