Introduce new routing backend (#790)

This is a huge backend change that essentially started with rewriting
the concurrency handling for processes and blew up to a refactor of the
entire application. In short these are the improvements:

**Better state and life cycle management:** 

Life cycle management of processes has always been the trickiest part of
the code. Juggling mutex locks between multiple locations to reduce race
conditions was complex. Too complex for my feeble brain to build a
simple mental model around as llama-swap gained more features. All of
that has been refactored. Most of the locks are gone, replaced with a
single run() that owns all state changes. There is one place to start
from now to understand and extend routing logic.

The improved life cycle management makes it easier to implement more
complex swap optimization strategies in the future like #727.

**Collation of requests:**

llama-swap previously handled requests and swapping in the order they
came in. For example requests for models in this order ABCABC would
result in 5 swaps. Now those requests are handled in this order AABBCC.
The result is less time waiting for swap under a high churn request
queue. This fixes #588 #612.

A possible future enhancement is to support a starvation parameter so
swap can be forced when models have been waiting too long.

**Shared base implementation for groups and swap matrix:** 

During the refactor it became clear that much of the swapping logic was
shared between these two implementations. That is not surprising
considering the swap matrix was added many moons after groups. Now they
share a common base and their specific swap strategies are implemented
into the swapPlanner interface.

Requests for bespoke or specific swapping scenarios is a common theme in
the issues. Now users can implement whatever bespoke and weird swapping
strategy they want in their own fork. Just ask your agent of choice to
implement swapPlanner. I'll still remaining more conservative on what
actually lands in core llama-swap and will continue to evaluate PRs if
the changes is good for everyone or just one specific use case.

**AI / Agentic Disclosure:** 

I paid very close attention to the low level swap concurrency design and
implementation. It's important to keep that essential part reliable,
boring and no surprises. Backwards compatibility was also maintained,
even the one way non-exclusive group model loading behaviour that people
have rightly pointed out be a weird design decision.

With the underlying swap core done the web server, api and UI sitting on
top were largely ported over with Claude Code and Opus 4.7 in multiple
phases. If you're curious I kept the changes in docs/newrouter-todo.md.
I did several passes to make sure things weren't left behind.

However, even frontier LLMs at the time of this PR still make small
decisions that don't make a lot of sense. They get shit wrong all the
time, just in small subtle way.

That said, there's likely to be some new bugs introduced with this
massive refactor. I'm fairly confident that there's no major
architectural flaws that would cause goal seeking agents to make dumb,
ugly code decisions.

For a little while the legacy llama-swap will be available under
cmd/legacy/llama-swap. The plan is to eventually delete that entry point
as well as the proxy package.

On a bit of a personal note, this PR is exciting and a bit sad for me. I
hand wrote much of the original code and this PR ultimately replaces
much of it. While the old code served as a good reference for the agent
to implement the new stuff it still a bit sad to eventually delete it
all.
This commit is contained in:
Benson Wong
2026-05-28 21:47:01 -07:00
committed by GitHub
parent 63bc266395
commit 02e015fa49
107 changed files with 12014 additions and 251 deletions
+266
View File
@@ -0,0 +1,266 @@
package server
import (
"encoding/json"
"net/http"
"sort"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/router"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
type modelRecord struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
}
// handleListModels serves the OpenAI-compatible model listing: local models
// (with optional aliases) plus peer models.
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
created := time.Now().Unix()
data := make([]modelRecord, 0, len(s.cfg.Models))
newRecord := func(id, name, description string, metadata map[string]any) modelRecord {
rec := modelRecord{
ID: id,
Object: "model",
Created: created,
OwnedBy: "llama-swap",
Name: strings.TrimSpace(name),
Description: strings.TrimSpace(description),
}
if len(metadata) > 0 {
rec.Meta = map[string]any{"llamaswap": metadata}
}
return rec
}
for id, mc := range s.cfg.Models {
if mc.Unlisted {
continue
}
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata))
if s.cfg.IncludeAliasesInList {
for _, alias := range mc.Aliases {
if alias := strings.TrimSpace(alias); alias != "" {
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata))
}
}
}
}
for peerID, peer := range s.cfg.Peers {
for _, modelID := range peer.Models {
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}))
}
}
sort.Slice(data, func(i, j int) bool { return data[i].ID < data[j].ID })
// Echo the Origin so browser clients can read the listing.
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"object": "list",
"data": data,
})
}
// runningModel is one entry in the /running listing.
type runningModel struct {
Model string `json:"model"`
State string `json:"state"`
Cmd string `json:"cmd"`
Proxy string `json:"proxy"`
TTL int `json:"ttl"`
Name string `json:"name"`
Description string `json:"description"`
}
// handleUnload stops every running local process. Peer models are remote and
// unaffected.
func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) {
s.local.Unload(0)
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
// handleRunning lists local processes that are not stopped, joining each model
// ID against its config for the cmd/proxy/ttl/name/description metadata.
func (s *Server) handleRunning(w http.ResponseWriter, r *http.Request) {
states := s.local.RunningModels()
list := make([]runningModel, 0, len(states))
for id, state := range states {
mc := s.cfg.Models[id]
list = append(list, runningModel{
Model: id,
State: string(state),
Cmd: mc.Cmd,
Proxy: mc.Proxy,
TTL: mc.UnloadAfter,
Name: mc.Name,
Description: mc.Description,
})
}
sort.Slice(list, func(i, j int) bool { return list[i].Model < list[j].Model })
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{"running": list})
}
// discardResponseWriter satisfies http.ResponseWriter for preload requests,
// dropping the body while capturing the status code.
type discardResponseWriter struct {
header http.Header
status int
}
func (d *discardResponseWriter) Header() http.Header {
if d.header == nil {
d.header = make(http.Header)
}
return d.header
}
func (d *discardResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
func (d *discardResponseWriter) WriteHeader(status int) { d.status = status }
// startPreload fires a background GET / at every model named in
// Hooks.OnStartup.Preload so they are warm before the first real request.
// Preload names are already resolved to real model IDs by config loading.
func (s *Server) startPreload() {
models := s.cfg.Hooks.OnStartup.Preload
if len(models) == 0 {
return
}
go func() {
for _, modelID := range models {
if !s.local.Handles(modelID) {
s.proxylog.Warnf("preload: model %s is not a local model, skipping", modelID)
continue
}
s.proxylog.Infof("preloading model: %s", modelID)
req, err := http.NewRequestWithContext(s.shutdownCtx, http.MethodGet, "/", nil)
if err != nil {
continue
}
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
dw := &discardResponseWriter{status: http.StatusOK}
s.local.ServeHTTP(dw, req)
success := dw.status < http.StatusBadRequest
if !success {
s.proxylog.Errorf("failed to preload model %s: status %d", modelID, dw.status)
}
event.Emit(shared.ModelPreloadedEvent{ModelName: modelID, Success: success})
}
}()
}
// handleMetrics serves Prometheus-format performance metrics. Returns 503 when
// performance monitoring is disabled.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
if s.perf == nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte("# performance monitor not available\n"))
return
}
s.perf.MetricsHandler().ServeHTTP(w, r)
}
func handleHealth(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
func handleRootRedirect(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/ui", http.StatusFound)
}
func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/ui/models", http.StatusFound)
}
// handleUpstream proxies ANY request under /upstream/<model>/<path> directly to
// the model's process, bypassing model dispatch by body/query inspection.
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
upstreamPath := r.PathValue("upstreamPath")
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
if !found {
router.SendResponse(w, r, http.StatusNotFound, "model not found")
return
}
// Redirect /upstream/model to /upstream/model/ so relative URLs in upstream
// responses resolve. 301 for GET/HEAD, 308 otherwise to preserve the method.
if remainingPath == "/" && !strings.HasSuffix(r.URL.Path, "/") {
newPath := "/upstream/" + searchName + "/"
if r.URL.RawQuery != "" {
newPath += "?" + r.URL.RawQuery
}
if r.Method == http.MethodGet || r.Method == http.MethodHead {
http.Redirect(w, r, newPath, http.StatusMovedPermanently)
} else {
http.Redirect(w, r, newPath, http.StatusPermanentRedirect)
}
return
}
// Strip the /upstream/<model> prefix before forwarding.
r.URL.Path = remainingPath
// Pin the resolved model so the router skips body/query extraction.
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
switch {
case s.local.Handles(modelID):
s.local.ServeHTTP(w, r)
case s.peer.Handles(modelID):
s.peer.ServeHTTP(w, r)
default:
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
}
}
// findModelInPath walks a slash-separated path, building up segments until one
// matches a configured model. This resolves model names that contain slashes
// (e.g. "author/model"). Returns the matched name, its real model ID, the
// remaining path, and whether a match was found.
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
name := ""
for i, part := range parts {
if part == "" {
continue
}
if name == "" {
name = part
} else {
name = name + "/" + part
}
if modelID, ok := cfg.RealModelName(name); ok {
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
}
}
return "", "", "", false
}
+159
View File
@@ -0,0 +1,159 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
)
func TestServer_HandleListModels(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.cfg = config.Config{
Models: map[string]config.ModelConfig{
"visible": {Name: "Visible", Description: "a model"},
"hidden": {Unlisted: true},
},
Peers: config.PeerDictionaryConfig{
"peer1": {Models: []string{"remote-model"}},
},
}
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
req.Header.Set("Origin", "http://example.com")
s.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
t.Errorf("Access-Control-Allow-Origin = %q", got)
}
var resp struct {
Data []modelRecord `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode: %v", err)
}
ids := map[string]bool{}
for _, m := range resp.Data {
ids[m.ID] = true
}
if !ids["visible"] || !ids["remote-model"] {
t.Errorf("missing expected models: %v", ids)
}
if ids["hidden"] {
t.Error("unlisted model should not appear")
}
}
func TestServer_HandleListModels_Aliases(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.cfg = config.Config{
IncludeAliasesInList: true,
Models: map[string]config.ModelConfig{
"real": {Aliases: []string{"nick"}},
},
}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
var resp struct {
Data []modelRecord `json:"data"`
}
json.Unmarshal(w.Body.Bytes(), &resp)
ids := map[string]bool{}
for _, m := range resp.Data {
ids[m.ID] = true
}
if !ids["real"] || !ids["nick"] {
t.Errorf("expected alias entry; got %v", ids)
}
}
func TestServer_FindModelInPath(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{
"author/model": {},
"simple": {},
}}
cases := []struct {
path string
wantName string
wantRem string
wantFound bool
}{
{"/simple/v1/chat", "simple", "/v1/chat", true},
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
{"/author/model", "author/model", "/", true},
{"/missing/v1", "", "", false},
{"/", "", "", false},
}
for _, c := range cases {
name, _, rem, found := findModelInPath(cfg, c.path)
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
}
}
}
func TestServer_HandleUpstream(t *testing.T) {
local := newStubRouter([]string{"m1"}, "upstream-body")
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
t.Run("proxies to local", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat", nil))
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
t.Errorf("status=%d body=%q", w.Code, w.Body.String())
}
})
t.Run("redirects bare model path", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1", nil))
if w.Code != http.StatusMovedPermanently {
t.Errorf("status = %d, want 301", w.Code)
}
})
t.Run("unknown model 404", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/nope/v1", nil))
if w.Code != http.StatusNotFound {
t.Errorf("status = %d, want 404", w.Code)
}
})
}
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/metrics", nil))
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want 503", w.Code)
}
}
func TestServer_Redirects(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
for path, want := range map[string]string{"/": "/ui", "/upstream": "/ui/models"} {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
if w.Code != http.StatusFound {
t.Errorf("%s: status = %d, want 302", path, w.Code)
}
if got := w.Header().Get("Location"); got != want {
t.Errorf("%s: Location = %q, want %q", path, got, want)
}
}
}
+270
View File
@@ -0,0 +1,270 @@
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/perf"
"github.com/mostlygeek/llama-swap/internal/router"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// apiModel is one entry in the /api/events modelStatus payload.
type apiModel struct {
Id string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
PeerID string `json:"peerID"`
Aliases []string `json:"aliases,omitempty"`
}
// modelStatus returns every configured model joined with its current process
// state (defaulting to "stopped"), followed by peer models.
func (s *Server) modelStatus() []apiModel {
running := s.local.RunningModels()
ids := make([]string, 0, len(s.cfg.Models))
for id := range s.cfg.Models {
ids = append(ids, id)
}
sort.Strings(ids)
models := make([]apiModel, 0, len(ids))
for _, id := range ids {
mc := s.cfg.Models[id]
state := "stopped"
if st, ok := running[id]; ok {
state = string(st)
}
models = append(models, apiModel{
Id: id,
Name: mc.Name,
Description: mc.Description,
State: state,
Unlisted: mc.Unlisted,
Aliases: mc.Aliases,
})
}
for peerID, peer := range s.cfg.Peers {
for _, modelID := range peer.Models {
models = append(models, apiModel{Id: modelID, PeerID: peerID})
}
}
return models
}
// handleAPIUnloadAll stops every running local process.
func (s *Server) handleAPIUnloadAll(w http.ResponseWriter, r *http.Request) {
s.local.Unload(0)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"msg": "ok"})
}
// handleAPIUnloadModel stops a single named local process.
func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
requested := strings.TrimPrefix(r.PathValue("model"), "/")
realName, found := s.cfg.RealModelName(requested)
if !found {
router.SendResponse(w, r, http.StatusNotFound, "model not found")
return
}
if !s.local.Handles(realName) {
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
return
}
s.local.Unload(0, realName)
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
// handleAPIMetrics serves the activity log as a JSON array.
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
data, err := s.metrics.getMetricsJSON()
if err != nil {
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(data)
}
// handleAPIPerformance serves the buffered system/GPU stats, optionally
// filtered to samples after the ?after=<RFC3339> timestamp.
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
if s.perf == nil {
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
return
}
sysStats, gpuStats := s.perf.Current()
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
after, err := time.Parse(time.RFC3339, afterStr)
if err != nil {
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
return
}
filteredSys := make([]perf.SysStat, 0, len(sysStats))
for _, st := range sysStats {
if st.Timestamp.After(after) {
filteredSys = append(filteredSys, st)
}
}
sysStats = filteredSys
filteredGpu := make([]perf.GpuStat, 0, len(gpuStats))
for _, g := range gpuStats {
if g.Timestamp.After(after) {
filteredGpu = append(filteredGpu, g)
}
}
gpuStats = filteredGpu
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"sys_stats": sysStats,
"gpu_stats": gpuStats,
})
}
// handleAPIVersion serves the build metadata.
func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{
"version": s.build.Version,
"commit": s.build.Commit,
"build_date": s.build.Date,
})
}
// handleAPICapture returns the stored request/response capture for a metric ID.
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
id, err := strconv.Atoi(r.PathValue("id"))
if err != nil {
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
return
}
capture := s.metrics.getCaptureByID(id)
if capture == nil {
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
return
}
jsonBytes, err := json.Marshal(capture)
if err != nil {
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(jsonBytes)
}
type messageType string
const (
msgTypeModelStatus messageType = "modelStatus"
msgTypeLogData messageType = "logData"
msgTypeMetrics messageType = "metrics"
msgTypeInFlight messageType = "inflight"
)
type messageEnvelope struct {
Type messageType `json:"type"`
Data string `json:"data"`
}
// handleAPIEvents streams server events (model status, log data, metrics,
// in-flight counts) to the client as Server-Sent Events.
func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Content-Type-Options", "nosniff")
// prevent nginx from buffering SSE
w.Header().Set("X-Accel-Buffering", "no")
flusher, ok := w.(http.Flusher)
if !ok {
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
return
}
// internal/event already has a 50K event buffer
// a 1K message buffer should be enough, watch the logs for the warning that the sendBuffer is full
sendBuffer := make(chan messageEnvelope, 1024)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
send := func(msg messageEnvelope) {
select {
case sendBuffer <- msg:
case <-ctx.Done():
s.proxylog.Warn("handleAPIEvents send suppressed due to context done")
default:
s.proxylog.Warn("handleAPIEvents sendBuffer full, dropped message")
}
}
sendModels := func() {
if data, err := json.Marshal(s.modelStatus()); err == nil {
send(messageEnvelope{Type: msgTypeModelStatus, Data: string(data)})
}
}
sendLogData := func(source string, data []byte) {
if j, err := json.Marshal(map[string]string{"source": source, "data": string(data)}); err == nil {
send(messageEnvelope{Type: msgTypeLogData, Data: string(j)})
}
}
sendMetrics := func(metrics []ActivityLogEntry) {
if j, err := json.Marshal(metrics); err == nil {
send(messageEnvelope{Type: msgTypeMetrics, Data: string(j)})
}
}
sendInFlight := func(total int) {
if j, err := json.Marshal(map[string]int{"total": total}); err == nil {
send(messageEnvelope{Type: msgTypeInFlight, Data: string(j)})
}
}
defer event.On(func(e shared.ProcessStateChangeEvent) { sendModels() })()
defer event.On(func(e shared.ConfigFileChangedEvent) { sendModels() })()
defer s.proxylog.OnLogData(func(data []byte) { sendLogData("proxy", data) })()
defer s.upstreamlog.OnLogData(func(data []byte) { sendLogData("upstream", data) })()
defer event.On(func(e ActivityLogEvent) { sendMetrics([]ActivityLogEntry{e.Metrics}) })()
defer event.On(func(e shared.InFlightRequestsEvent) { sendInFlight(e.Total) })()
// initial payload
sendLogData("proxy", s.proxylog.GetHistory())
sendLogData("upstream", s.upstreamlog.GetHistory())
sendModels()
sendMetrics(s.metrics.getMetrics())
sendInFlight(int(s.inflight.Current()))
for {
select {
case <-r.Context().Done():
return
case <-s.shutdownCtx.Done():
return
case msg := <-sendBuffer:
data, err := json.Marshal(msg)
if err != nil {
continue
}
fmt.Fprintf(w, "event:message\ndata:%s\n\n", data)
flusher.Flush()
}
}
}
+103
View File
@@ -0,0 +1,103 @@
package server
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestServer_InflightMiddleware(t *testing.T) {
c := &inflightCounter{}
mw := CreateInflightMiddleware(c)
var duringRequest int64
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
duringRequest = c.Current()
}))
handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil))
if duringRequest != 1 {
t.Errorf("counter during request = %d, want 1", duringRequest)
}
if got := c.Current(); got != 0 {
t.Errorf("counter after request = %d, want 0", got)
}
}
func TestServer_APIVersion(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.build = BuildInfo{Version: "1.2.3", Commit: "deadbeef", Date: "2026-05-19"}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/version", nil))
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
var got map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got["version"] != "1.2.3" || got["commit"] != "deadbeef" || got["build_date"] != "2026-05-19" {
t.Errorf("body = %v", got)
}
}
func TestServer_APIMetrics_Empty(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/metrics", nil))
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
if body := strings.TrimSpace(w.Body.String()); body != "[]" {
t.Errorf("body = %q, want []", body)
}
}
func TestServer_APIPerformance_Unavailable(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/performance", nil))
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want 503", w.Code)
}
}
func TestServer_APIEvents_InitialPayload(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest(http.MethodGet, "/api/events", nil).WithContext(ctx)
w := httptest.NewRecorder()
done := make(chan struct{})
go func() {
s.ServeHTTP(w, req)
close(done)
}()
time.Sleep(100 * time.Millisecond)
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("handler did not return after context cancel")
}
body := w.Body.String()
for _, want := range []string{`"type":"modelStatus"`, `"type":"inflight"`, `"type":"logData"`} {
if !strings.Contains(body, want) {
t.Errorf("initial SSE payload missing %s; body=%q", want, body)
}
}
}
+135
View File
@@ -0,0 +1,135 @@
package server
import (
"encoding/base64"
"net/http"
"strings"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/router"
)
// CreateAuthMiddleware returns middleware that validates API keys when the
// config declares any. It accepts the key via Authorization: Bearer,
// Authorization: Basic (password field), or x-api-key. On success the auth
// headers are stripped so they never leak to upstream. When no keys are
// configured the middleware is a pass-through.
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
keys := cfg.RequiredAPIKeys
return func(next http.Handler) http.Handler {
if len(keys) == 0 {
return next
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
provided := extractAPIKey(r)
valid := false
for _, key := range keys {
if provided == key {
valid = true
break
}
}
if !valid {
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
return
}
r.Header.Del("Authorization")
r.Header.Del("x-api-key")
next.ServeHTTP(w, r)
})
}
}
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
// then Bearer, then x-api-key.
func extractAPIKey(r *http.Request) string {
var bearerKey, basicKey string
if auth := r.Header.Get("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
bearerKey = strings.TrimPrefix(auth, "Bearer ")
} else if strings.HasPrefix(auth, "Basic ") {
encoded := strings.TrimPrefix(auth, "Basic ")
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
basicKey = parts[1] // password field is the API key
}
}
}
}
switch {
case basicKey != "":
return basicKey
case bearerKey != "":
return bearerKey
default:
return r.Header.Get("x-api-key")
}
}
// CreateCORSMiddleware returns middleware that answers OPTIONS preflight
// requests with permissive CORS headers (see issues #81, #77, #42). Non-OPTIONS
// requests pass through untouched.
func CreateCORSMiddleware() chain.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodOptions {
next.ServeHTTP(w, r)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
w.Header().Set("Access-Control-Allow-Headers", sanitizeAccessControlRequestHeaderValues(headers))
} else {
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept, X-Requested-With")
}
w.Header().Set("Access-Control-Max-Age", "86400")
w.WriteHeader(http.StatusNoContent)
})
}
}
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
// sanitizeAccessControlRequestHeaderValues drops any header names that contain
// characters outside the HTTP token grammar before echoing them back.
func sanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
+120
View File
@@ -0,0 +1,120 @@
package server
import (
"encoding/base64"
"net/http"
"net/http/httptest"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
)
func TestServer_ExtractAPIKey(t *testing.T) {
basicHeader := func(user, pass string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
}
cases := []struct {
name string
auth string
xapi string
want string
}{
{"none", "", "", ""},
{"bearer", "Bearer tok123", "", "tok123"},
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
{"x-api-key", "", "xkey", "xkey"},
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
if c.auth != "" {
r.Header.Set("Authorization", c.auth)
}
if c.xapi != "" {
r.Header.Set("x-api-key", c.xapi)
}
if got := extractAPIKey(r); got != c.want {
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
}
})
}
}
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
cases := []struct {
in string
want string
}{
{"Content-Type, Authorization", "Content-Type, Authorization"},
{" X-Custom , Accept ", "X-Custom, Accept"},
{"Valid, Bad Header", "Valid"},
{"Bad@Header", ""},
{"", ""},
}
for _, c := range cases {
if got := sanitizeAccessControlRequestHeaderValues(c.in); got != c.want {
t.Errorf("sanitize(%q) = %q, want %q", c.in, got, c.want)
}
}
}
func TestServer_IsTokenChar(t *testing.T) {
for _, r := range "abcXYZ0129!#$%&'*+-.^_`|~" {
if !isTokenChar(r) {
t.Errorf("isTokenChar(%q) = false, want true", r)
}
}
for _, r := range " @()/\t\"" {
if isTokenChar(r) {
t.Errorf("isTokenChar(%q) = true, want false", r)
}
}
}
func TestServer_AuthMiddleware(t *testing.T) {
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
t.Error("auth headers leaked to upstream")
}
w.WriteHeader(http.StatusOK)
})
t.Run("no keys configured passes through", func(t *testing.T) {
mw := CreateAuthMiddleware(config.Config{})
w := httptest.NewRecorder()
mw(final).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
cfg := config.Config{RequiredAPIKeys: []string{"secret"}}
t.Run("valid key", func(t *testing.T) {
mw := CreateAuthMiddleware(cfg)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Authorization", "Bearer secret")
w := httptest.NewRecorder()
mw(final).ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
t.Run("invalid key", func(t *testing.T) {
mw := CreateAuthMiddleware(cfg)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Authorization", "Bearer wrong")
w := httptest.NewRecorder()
mw(final).ServeHTTP(w, r)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", w.Code)
}
if w.Header().Get("WWW-Authenticate") == "" {
t.Error("missing WWW-Authenticate header")
}
})
}
+176
View File
@@ -0,0 +1,176 @@
package server
import (
"fmt"
"net/http"
"strings"
"sync"
"github.com/fxamacker/cbor/v2"
"github.com/klauspost/compress/zstd"
)
// ReqRespCapture is a stored request/response pair for a single metered request.
type ReqRespCapture struct {
ID int `json:"id"`
ReqPath string `json:"req_path"`
ReqHeaders map[string]string `json:"req_headers"`
ReqBody []byte `json:"req_body"`
RespHeaders map[string]string `json:"resp_headers"`
RespBody []byte `json:"resp_body"`
}
// captureFields is a bitmask controlling what a route stores in a ReqRespCapture.
type captureFields uint
const (
captureReqHeaders captureFields = 1 << iota
captureReqBody
captureRespHeaders
captureRespBody
)
const (
captureReqAll = captureReqHeaders | captureReqBody
captureRespAll = captureRespHeaders | captureRespBody
captureAll = captureReqAll | captureRespAll
)
// captureFieldsByPath overrides the default capture mask for routes carrying
// large binary payloads (audio/image) where storing the full body is wasteful.
var captureFieldsByPath = map[string]captureFields{
"/v1/audio/speech": captureReqAll | captureRespHeaders,
"/v1/audio/voices": captureReqHeaders | captureRespAll,
"/v1/audio/transcriptions": captureReqHeaders | captureRespHeaders | captureRespBody,
"/v1/images/generations": captureReqAll | captureRespHeaders,
"/v1/images/edits": captureReqHeaders | captureRespHeaders,
"/sdapi/v1/txt2img": captureReqAll | captureRespHeaders,
"/sdapi/v1/img2img": captureReqHeaders | captureRespHeaders,
}
// captureFieldsFor returns the capture mask for a request path. Unlisted routes
// (the OpenAI-compatible JSON endpoints) capture everything.
func captureFieldsFor(path string) captureFields {
if cf, ok := captureFieldsByPath[path]; ok {
return cf
}
return captureAll
}
// zstdEncOptions are the shared zstd encoder options for maximum compression.
var zstdEncOptions = []zstd.EOption{
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
}
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
var zstdEncPool = &sync.Pool{
New: func() interface{} {
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
return enc
},
}
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
var zstdDecPool = &sync.Pool{
New: func() interface{} {
dec, _ := zstd.NewReader(nil)
return dec
},
}
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
// Returns the compressed bytes and the original CBOR byte count for logging.
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
cborBytes, err := cbor.Marshal(c)
if err != nil {
return nil, 0, fmt.Errorf("marshal capture: %w", err)
}
zenc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(zenc)
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
}
// decompressCapture decompresses zstd-compressed CBOR into a ReqRespCapture.
func decompressCapture(data []byte) (*ReqRespCapture, error) {
dec := zstdDecPool.Get().(*zstd.Decoder)
defer zstdDecPool.Put(dec)
cborBytes, err := dec.DecodeAll(data, nil)
if err != nil {
return nil, fmt.Errorf("decompress capture: %w", err)
}
var capture ReqRespCapture
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
return nil, fmt.Errorf("unmarshal capture: %w", err)
}
return &capture, nil
}
// addCapture compresses and stores a capture in the cache. Returns true if the
// capture was stored.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
if !mp.enableCaptures {
return false
}
compressed, uncompressedBytes, err := compressCapture(&capture)
if err != nil {
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
return false
}
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
return false
}
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
return true
}
// getCaptureByID decompresses and unmarshals a capture by ID. Returns nil if
// the capture is not found or decompression fails.
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
if mp.captureCache == nil {
return nil
}
data, err := mp.captureCache.Get(id)
if err != nil {
return nil
}
capture, err := decompressCapture(data)
if err != nil {
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
return nil
}
return capture
}
// sensitiveHeaders lists headers that are redacted in captures.
var sensitiveHeaders = map[string]bool{
"authorization": true,
"proxy-authorization": true,
"cookie": true,
"set-cookie": true,
"x-api-key": true,
}
// headerMap flattens an http.Header to a single-value map.
func headerMap(h http.Header) map[string]string {
m := make(map[string]string, len(h))
for key, values := range h {
if len(values) > 0 {
m[key] = values[0]
}
}
return m
}
// redactHeaders replaces sensitive header values in-place with "[REDACTED]".
func redactHeaders(headers map[string]string) {
for key := range headers {
if sensitiveHeaders[strings.ToLower(key)] {
headers[key] = "[REDACTED]"
}
}
}
+79
View File
@@ -0,0 +1,79 @@
package server
import (
"bytes"
"io"
"testing"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
func TestServer_CaptureCompressRoundtrip(t *testing.T) {
orig := &ReqRespCapture{
ID: 7,
ReqPath: "/v1/chat/completions",
ReqHeaders: map[string]string{"Content-Type": "application/json"},
ReqBody: []byte(`{"model":"m"}`),
RespHeaders: map[string]string{"Content-Type": "application/json"},
RespBody: []byte(`{"usage":{}}`),
}
compressed, uncompressed, err := compressCapture(orig)
if err != nil {
t.Fatalf("compressCapture: %v", err)
}
if uncompressed == 0 || len(compressed) == 0 {
t.Fatalf("unexpected sizes: uncompressed=%d compressed=%d", uncompressed, len(compressed))
}
got, err := decompressCapture(compressed)
if err != nil {
t.Fatalf("decompressCapture: %v", err)
}
if got.ID != orig.ID || got.ReqPath != orig.ReqPath ||
!bytes.Equal(got.ReqBody, orig.ReqBody) || !bytes.Equal(got.RespBody, orig.RespBody) {
t.Fatalf("roundtrip mismatch: %+v", got)
}
}
func TestServer_CaptureStoreAndRetrieve(t *testing.T) {
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
if !mm.enableCaptures {
t.Fatal("captures should be enabled with non-zero buffer")
}
capture := ReqRespCapture{ID: 3, ReqPath: "/v1/chat/completions", ReqBody: []byte("hello")}
if !mm.addCapture(capture) {
t.Fatal("addCapture returned false")
}
got := mm.getCaptureByID(3)
if got == nil || !bytes.Equal(got.ReqBody, []byte("hello")) {
t.Fatalf("getCaptureByID = %+v", got)
}
if mm.getCaptureByID(999) != nil {
t.Fatal("expected nil for unknown capture ID")
}
}
func TestServer_CaptureDisabled(t *testing.T) {
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 0)
if mm.enableCaptures {
t.Fatal("captures should be disabled with zero buffer")
}
if mm.addCapture(ReqRespCapture{ID: 1}) {
t.Fatal("addCapture should return false when disabled")
}
if mm.getCaptureByID(1) != nil {
t.Fatal("getCaptureByID should return nil when disabled")
}
}
func TestServer_CaptureFieldsFor(t *testing.T) {
if got := captureFieldsFor("/v1/chat/completions"); got != captureAll {
t.Fatalf("default = %b, want captureAll", got)
}
if got := captureFieldsFor("/v1/audio/speech"); got != captureReqAll|captureRespHeaders {
t.Fatalf("/v1/audio/speech = %b", got)
}
}
+55
View File
@@ -0,0 +1,55 @@
package server
import (
"net/http"
"golang.org/x/sync/semaphore"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/router"
)
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
// the model config leaves concurrencyLimit unset. Matches the legacy
// proxy.Process default.
const defaultConcurrencyLimit = 10
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
// model-dispatched requests per model. Each model gets a semaphore sized to
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
// immediately acquire a slot is rejected with 429. Models without a local
// config entry (e.g. peer-routed models) are not limited.
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
for id, mc := range cfg.Models {
limit := defaultConcurrencyLimit
if mc.ConcurrencyLimit > 0 {
limit = mc.ConcurrencyLimit
}
semaphores[id] = semaphore.NewWeighted(int64(limit))
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := router.FetchContext(r, cfg)
if err != nil {
router.SendError(w, r, router.ErrNoModelInContext)
return
}
// fall through for peer models
sem, ok := semaphores[data.ModelID]
if !ok {
next.ServeHTTP(w, r)
return
}
if !sem.TryAcquire(1) {
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}
defer sem.Release(1)
next.ServeHTTP(w, r)
})
}
}
+75
View File
@@ -0,0 +1,75 @@
package server
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/router"
)
func concurrencyTestReq(model string) *http.Request {
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model}))
}
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
cfg := config.Config{
Models: map[string]config.ModelConfig{
"m1": {ConcurrencyLimit: 1},
},
}
entered := make(chan struct{})
release := make(chan struct{})
var once sync.Once
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
once.Do(func() { close(entered) })
<-release
w.WriteHeader(http.StatusOK)
})
h := CreateConcurrencyMiddleware(cfg)(final)
// First request occupies the only slot.
done := make(chan struct{})
go func() {
defer close(done)
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
}()
<-entered
// Second concurrent request is rejected with 429.
w := httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("m1"))
if w.Code != http.StatusTooManyRequests {
t.Fatalf("over-limit status = %d, want 429", w.Code)
}
// Once the slot frees, a new request succeeds.
close(release)
<-done
w = httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("m1"))
if w.Code != http.StatusOK {
t.Fatalf("post-release status = %d, want 200", w.Code)
}
}
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{}}
called := 0
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called++
w.WriteHeader(http.StatusOK)
})
h := CreateConcurrencyMiddleware(cfg)(final)
w := httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
if w.Code != http.StatusOK || called != 1 {
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
}
}
+205
View File
@@ -0,0 +1,205 @@
package server
import (
"bytes"
"compress/flate"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
func TestServer_DecompressBody(t *testing.T) {
plain := []byte("hello world")
var gz bytes.Buffer
gw := gzip.NewWriter(&gz)
gw.Write(plain)
gw.Close()
var fl bytes.Buffer
fw, _ := flate.NewWriter(&fl, flate.DefaultCompression)
fw.Write(plain)
fw.Close()
cases := []struct {
name string
body []byte
encoding string
}{
{"plain", plain, ""},
{"gzip", gz.Bytes(), "gzip"},
{"deflate", fl.Bytes(), "deflate"},
{"unknown passthrough", plain, "br"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := decompressBody(c.body, c.encoding)
if err != nil {
t.Fatalf("decompressBody: %v", err)
}
if !bytes.Equal(got, plain) {
t.Errorf("got %q, want %q", got, plain)
}
})
}
}
func TestServer_FilterAcceptEncoding(t *testing.T) {
cases := []struct {
in string
want string
}{
{"", ""},
{"gzip, deflate, br", "gzip, deflate"},
{"br, zstd", ""},
{"gzip;q=1.0", "gzip;q=1.0"},
}
for _, c := range cases {
if got := filterAcceptEncoding(c.in); got != c.want {
t.Errorf("filterAcceptEncoding(%q) = %q, want %q", c.in, got, c.want)
}
}
}
func TestServer_BodyCopier_Flush(t *testing.T) {
bc := newBodyCopier(httptest.NewRecorder())
bc.Write([]byte("data"))
bc.Flush()
if bc.Status() != http.StatusOK {
t.Errorf("status = %d, want 200", bc.Status())
}
}
func TestServer_HeaderMapAndRedact(t *testing.T) {
h := http.Header{
"Content-Type": {"application/json"},
"Authorization": {"Bearer secret"},
"X-Api-Key": {"key123"},
}
m := headerMap(h)
if m["Content-Type"] != "application/json" {
t.Errorf("Content-Type = %q", m["Content-Type"])
}
redactHeaders(m)
if m["Authorization"] != "[REDACTED]" || m["X-Api-Key"] != "[REDACTED]" {
t.Errorf("sensitive headers not redacted: %v", m)
}
if m["Content-Type"] != "application/json" {
t.Error("non-sensitive header should not be redacted")
}
}
func TestServer_StripVersionPrefix(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/v/v1/chat", nil)
stripVersionPrefix(r)
if r.URL.Path != "/v1/chat" {
t.Errorf("path = %q, want /v1/chat", r.URL.Path)
}
r2 := httptest.NewRequest(http.MethodGet, "/v1/chat", nil)
stripVersionPrefix(r2)
if r2.URL.Path != "/v1/chat" {
t.Errorf("path = %q, want unchanged", r2.URL.Path)
}
}
func TestServer_CloseStreams(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.CloseStreams()
select {
case <-s.shutdownCtx.Done():
default:
t.Error("CloseStreams did not cancel shutdown context")
}
s.CloseStreams() // idempotent
}
func TestServer_HandleUIAndFavicon(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
for _, path := range []string{"/ui/", "/favicon.ico"} {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
// The embedded ui_dist only carries placeholder.txt in test builds, so
// these resolve to 404 — the handlers still execute end to end.
if w.Code != http.StatusOK && w.Code != http.StatusNotFound {
t.Errorf("%s: status = %d", path, w.Code)
}
}
}
func TestServer_HandleAPIUnloadAll(t *testing.T) {
local := newStubRouter([]string{"m1"}, "")
s := newTestServer(local, newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload", nil))
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
if local.unloadCalls.Load() != 1 {
t.Errorf("unloadCalls = %d, want 1", local.unloadCalls.Load())
}
}
func TestServer_HandleAPIUnloadModel(t *testing.T) {
local := newStubRouter([]string{"m1"}, "")
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
t.Run("known model", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/m1", nil))
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
t.Run("unknown model 404", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/nope", nil))
if w.Code != http.StatusNotFound {
t.Errorf("status = %d, want 404", w.Code)
}
})
}
func TestServer_HandleAPICapture(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.metrics = newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
s.metrics.addCapture(ReqRespCapture{ID: 42, ReqPath: "/v1/chat/completions"})
t.Run("found", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/42", nil))
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
if !bytes.Contains(w.Body.Bytes(), []byte("/v1/chat/completions")) {
t.Errorf("body = %q", w.Body.String())
}
})
t.Run("not found", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/999", nil))
if w.Code != http.StatusNotFound {
t.Errorf("status = %d, want 404", w.Code)
}
})
t.Run("invalid id", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/abc", nil))
if w.Code != http.StatusBadRequest {
t.Errorf("status = %d, want 400", w.Code)
}
})
}
+218
View File
@@ -0,0 +1,218 @@
package server
import (
"bytes"
"fmt"
"io"
"mime/multipart"
"net/http"
"strconv"
"strings"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/router"
"github.com/tidwall/sjson"
)
// CreateFilterMiddleware returns middleware that applies per-model request-body
// filters to JSON requests before they are forwarded upstream:
//
// - UseModelName rewrite (issue #69)
// - StripParams removal (issue #174)
// - SetParams injection (issue #453)
// - SetParamsByID per-alias overrides
//
// Non-JSON requests (GET, multipart forms) pass through untouched. The buffered
// body is re-attached with Content-Length / Transfer-Encoding cleanup so the
// downstream reverse proxy forwards the correct bytes (see issue #11).
func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Content-Type"), "application/json") {
next.ServeHTTP(w, r)
return
}
data, err := router.FetchContext(r, cfg)
if err != nil {
router.SendError(w, r, router.ErrNoModelInContext)
return
}
useModelName, filters, ok := resolveFilters(cfg, data.Model)
if !ok {
next.ServeHTTP(w, r)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
router.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
return
}
body, err = applyFilters(body, data.Model, useModelName, filters)
if err != nil {
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
return
}
r.Body = io.NopCloser(bytes.NewReader(body))
r.Header.Del("Transfer-Encoding")
r.Header.Set("Content-Length", strconv.Itoa(len(body)))
r.ContentLength = int64(len(body))
next.ServeHTTP(w, r)
})
}
}
// CreateFormFilterMiddleware returns middleware that applies the UseModelName
// rewrite (issue #69) to multipart/form-data requests before they are forwarded
// upstream. JSON-body filters (StripParams, SetParams) do not apply to form
// endpoints; only the "model" field is rewritten.
//
// Non-multipart requests pass through untouched. When a rewrite is needed the
// form is reconstructed and re-attached with Content-Type / Content-Length
// cleanup so the downstream reverse proxy forwards the correct bytes.
func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
next.ServeHTTP(w, r)
return
}
data, err := router.FetchContext(r, cfg)
if err != nil {
router.SendError(w, r, router.ErrNoModelInContext)
return
}
useModelName, _, ok := resolveFilters(cfg, data.Model)
if !ok || useModelName == "" {
next.ServeHTTP(w, r)
return
}
if err := r.ParseMultipartForm(32 << 20); err != nil {
router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
if err != nil {
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
return
}
r.Body = io.NopCloser(bytes.NewReader(body))
r.MultipartForm = nil
r.Header.Del("Transfer-Encoding")
r.Header.Set("Content-Type", contentType)
r.Header.Set("Content-Length", strconv.Itoa(len(body)))
r.ContentLength = int64(len(body))
next.ServeHTTP(w, r)
})
}
}
// rewriteMultipartModel reconstructs a multipart form, replacing the "model"
// field value with useModelName. It returns the encoded body and the matching
// Content-Type header (which carries the generated boundary).
func rewriteMultipartModel(form *multipart.Form, useModelName string) ([]byte, string, error) {
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
for key, values := range form.Value {
for _, value := range values {
if key == "model" {
value = useModelName
}
field, err := mw.CreateFormField(key)
if err != nil {
return nil, "", fmt.Errorf("error recreating form field %s: %w", key, err)
}
if _, err := field.Write([]byte(value)); err != nil {
return nil, "", fmt.Errorf("error writing form field %s: %w", key, err)
}
}
}
for key, headers := range form.File {
for _, fh := range headers {
part, err := mw.CreateFormFile(key, fh.Filename)
if err != nil {
return nil, "", fmt.Errorf("error recreating form file %s: %w", key, err)
}
file, err := fh.Open()
if err != nil {
return nil, "", fmt.Errorf("error opening uploaded file %s: %w", key, err)
}
if _, err := io.Copy(part, file); err != nil {
file.Close()
return nil, "", fmt.Errorf("error copying file data %s: %w", key, err)
}
file.Close()
}
}
if err := mw.Close(); err != nil {
return nil, "", fmt.Errorf("error finalizing multipart form: %w", err)
}
return buf.Bytes(), mw.FormDataContentType(), nil
}
// resolveFilters returns the filter settings for a requested model. UseModelName
// only applies to local models; peers carry filters but no name rewrite.
func resolveFilters(cfg config.Config, requested string) (useModelName string, filters config.Filters, ok bool) {
if realName, found := cfg.RealModelName(requested); found {
mc := cfg.Models[realName]
return mc.UseModelName, mc.Filters.Filters, true
}
for _, peer := range cfg.Peers {
for _, m := range peer.Models {
if m == requested {
return "", peer.Filters, true
}
}
}
return "", config.Filters{}, false
}
// applyFilters rewrites the JSON body in place. Order matches the legacy
// ProxyManager: useModelName, stripParams, setParams, then setParamsByID (which
// can override setParams).
func applyFilters(body []byte, requested, useModelName string, f config.Filters) ([]byte, error) {
var err error
if useModelName != "" {
if body, err = sjson.SetBytes(body, "model", useModelName); err != nil {
return nil, fmt.Errorf("error rewriting model name in JSON: %w", err)
}
}
for _, param := range f.SanitizedStripParams() {
if body, err = sjson.DeleteBytes(body, param); err != nil {
return nil, fmt.Errorf("error stripping parameter %s from request", param)
}
}
setParams, setKeys := f.SanitizedSetParams()
for _, key := range setKeys {
if body, err = sjson.SetBytes(body, key, setParams[key]); err != nil {
return nil, fmt.Errorf("error setting parameter %s in request", key)
}
}
byID, byIDKeys := f.SanitizedSetParamsByID(requested)
for _, key := range byIDKeys {
if body, err = sjson.SetBytes(body, key, byID[key]); err != nil {
return nil, fmt.Errorf("error setting parameter %s in request", key)
}
}
return body, nil
}
+132
View File
@@ -0,0 +1,132 @@
package server
import (
"bytes"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/tidwall/gjson"
)
func TestServer_ApplyFilters(t *testing.T) {
t.Run("useModelName rewrite", func(t *testing.T) {
out, err := applyFilters([]byte(`{"model":"alias","temp":1}`), "alias", "real-model", config.Filters{})
if err != nil {
t.Fatalf("applyFilters: %v", err)
}
if got := gjson.GetBytes(out, "model").String(); got != "real-model" {
t.Errorf("model = %q, want real-model", got)
}
})
t.Run("strip and set params", func(t *testing.T) {
f := config.Filters{
StripParams: "temperature",
SetParams: map[string]any{"top_p": 0.9},
}
out, err := applyFilters([]byte(`{"model":"m","temperature":0.7}`), "m", "", f)
if err != nil {
t.Fatalf("applyFilters: %v", err)
}
if gjson.GetBytes(out, "temperature").Exists() {
t.Error("temperature should be stripped")
}
if got := gjson.GetBytes(out, "top_p").Float(); got != 0.9 {
t.Errorf("top_p = %v, want 0.9", got)
}
})
t.Run("setParamsByID overrides setParams", func(t *testing.T) {
f := config.Filters{
SetParams: map[string]any{"top_p": 0.5},
SetParamsByID: map[string]map[string]any{"alias": {"top_p": 0.1}},
}
out, err := applyFilters([]byte(`{"model":"alias"}`), "alias", "", f)
if err != nil {
t.Fatalf("applyFilters: %v", err)
}
if got := gjson.GetBytes(out, "top_p").Float(); got != 0.1 {
t.Errorf("top_p = %v, want 0.1", got)
}
})
}
func TestServer_RewriteMultipartModel(t *testing.T) {
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
mw.WriteField("model", "old-name")
mw.WriteField("language", "en")
fw, _ := mw.CreateFormFile("file", "audio.wav")
fw.Write([]byte("RIFFdata"))
mw.Close()
r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
r.Header.Set("Content-Type", mw.FormDataContentType())
if err := r.ParseMultipartForm(32 << 20); err != nil {
t.Fatalf("ParseMultipartForm: %v", err)
}
body, contentType, err := rewriteMultipartModel(r.MultipartForm, "new-name")
if err != nil {
t.Fatalf("rewriteMultipartModel: %v", err)
}
parsed, err := multipart.NewReader(bytes.NewReader(body), boundaryOf(t, contentType)).ReadForm(32 << 20)
if err != nil {
t.Fatalf("re-parse: %v", err)
}
if got := parsed.Value["model"][0]; got != "new-name" {
t.Errorf("model = %q, want new-name", got)
}
if got := parsed.Value["language"][0]; got != "en" {
t.Errorf("language = %q, want en", got)
}
fh := parsed.File["file"][0]
f, _ := fh.Open()
data, _ := io.ReadAll(f)
f.Close()
if string(data) != "RIFFdata" {
t.Errorf("file data = %q, want RIFFdata", data)
}
}
func boundaryOf(t *testing.T, contentType string) string {
t.Helper()
_, params, ok := strings.Cut(contentType, "boundary=")
if !ok {
t.Fatalf("no boundary in %q", contentType)
}
return params
}
func TestServer_FormFilterMiddleware(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{
"whisper": {UseModelName: "whisper-large-v3"},
}}
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
mw.WriteField("model", "whisper")
fw, _ := mw.CreateFormFile("file", "a.wav")
fw.Write([]byte("xx"))
mw.Close()
r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
r.Header.Set("Content-Type", mw.FormDataContentType())
var gotModel string
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseMultipartForm(32 << 20)
gotModel = r.MultipartForm.Value["model"][0]
})
CreateFormFilterMiddleware(cfg)(final).ServeHTTP(httptest.NewRecorder(), r)
if gotModel != "whisper-large-v3" {
t.Errorf("model rewritten to %q, want whisper-large-v3", gotModel)
}
}
+33
View File
@@ -0,0 +1,33 @@
package server
import (
"net/http"
"sync/atomic"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// inflightCounter tracks the number of in-flight model-dispatched requests.
type inflightCounter struct {
total atomic.Int64
}
func (c *inflightCounter) Increment() int64 { return c.total.Add(1) }
func (c *inflightCounter) Decrement() int64 { return c.total.Add(-1) }
func (c *inflightCounter) Current() int64 { return c.total.Load() }
// CreateInflightMiddleware returns middleware that increments the counter on
// entry and decrements on exit, emitting an InFlightRequestsEvent for each.
func CreateInflightMiddleware(c *inflightCounter) chain.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
event.Emit(shared.InFlightRequestsEvent{Total: int(c.Increment())})
defer func() {
event.Emit(shared.InFlightRequestsEvent{Total: int(c.Decrement())})
}()
next.ServeHTTP(w, r)
})
}
}
+222
View File
@@ -0,0 +1,222 @@
package server
import (
"context"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/router"
)
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
// wiring each one's output per the logToStdout config value. The proxy and
// upstream monitors write into muxlog (rather than os.Stdout directly) so
// muxlog accumulates a combined history for the /logs endpoints, while each
// monitor keeps its own per-source history and event subscribers.
//
// Behaviour matches the legacy ProxyManager:
//
// - none: everything discarded
// - both: proxy + upstream both routed to muxlog -> stdout
// - upstream: only upstream routed to muxlog -> stdout; proxy discarded
// - proxy: only proxy routed to muxlog -> stdout; upstream discarded
//
// An empty or unrecognised value behaves like "proxy".
func NewLoggers(logToStdout string) (muxlog, proxylog, upstreamlog *logmon.Monitor) {
switch logToStdout {
case config.LogToStdoutNone:
muxlog = logmon.NewWriter(io.Discard)
proxylog = logmon.NewWriter(io.Discard)
upstreamlog = logmon.NewWriter(io.Discard)
case config.LogToStdoutBoth:
muxlog = logmon.NewWriter(os.Stdout)
proxylog = logmon.NewWriter(muxlog)
upstreamlog = logmon.NewWriter(muxlog)
case config.LogToStdoutUpstream:
muxlog = logmon.NewWriter(os.Stdout)
proxylog = logmon.NewWriter(io.Discard)
upstreamlog = logmon.NewWriter(muxlog)
default:
// config.LogToStdoutProxy, and the fallback for an unset value.
muxlog = logmon.NewWriter(os.Stdout)
proxylog = logmon.NewWriter(muxlog)
upstreamlog = logmon.NewWriter(io.Discard)
}
return muxlog, proxylog, upstreamlog
}
// handleLogs serves the historical proxy/upstream log. HTML clients are
// redirected to the UI.
func (s *Server) handleLogs(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.Header.Get("Accept"), "text/html") {
http.Redirect(w, r, "/ui/", http.StatusFound)
return
}
w.Header().Set("Content-Type", "text/plain")
w.Write(s.muxlog.GetHistory())
}
// getLogger resolves a log monitor by id. An empty id maps to the combined
// muxlog; "proxy" and "upstream" select the respective monitors.
func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
switch logMonitorID {
case "":
return s.muxlog, nil
case "proxy":
return s.proxylog, nil
case "upstream":
return s.upstreamlog, nil
default:
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found {
if log, ok := s.local.ProcessLogger(modelID); ok {
return log, nil
}
}
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
}
}
// handleLogStream tails a log monitor: it writes the history then streams live
// log data until the client disconnects or the server shuts down.
func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("X-Content-Type-Options", "nosniff")
// prevent nginx from buffering streamed logs
w.Header().Set("X-Accel-Buffering", "no")
logMonitorID := strings.TrimPrefix(r.PathValue("logMonitorID"), "/")
// Strip a query string if it leaked into the path segment.
if idx := strings.Index(logMonitorID, "?"); idx != -1 {
logMonitorID = logMonitorID[:idx]
}
logger, err := s.getLogger(logMonitorID)
if err != nil {
router.SendResponse(w, r, http.StatusBadRequest, err.Error())
return
}
flusher, ok := w.(http.Flusher)
if !ok {
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
return
}
_, skipHistory := r.URL.Query()["no-history"]
if !skipHistory {
if history := logger.GetHistory(); len(history) != 0 {
w.Write(history)
flusher.Flush()
}
}
sendChan := make(chan []byte, 10)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
cancelSub := logger.OnLogData(func(data []byte) {
select {
case sendChan <- data:
case <-ctx.Done():
default:
}
})
defer cancelSub()
for {
select {
case <-r.Context().Done():
return
case <-s.shutdownCtx.Done():
return
case data := <-sendChan:
w.Write(data)
flusher.Flush()
}
}
}
// requestLogPathSkips lists path prefixes excluded from the access log because
// they are polled frequently and would drown out useful entries.
var requestLogPathSkips = []string{"/wol-health", "/api/performance", "/metrics"}
// statusRecorder wraps an http.ResponseWriter to capture the response status
// code and the number of body bytes written, so the access log can report
// them. Flush is forwarded so streaming handlers (SSE) still work.
type statusRecorder struct {
http.ResponseWriter
status int
size int
}
func (sr *statusRecorder) WriteHeader(code int) {
sr.status = code
sr.ResponseWriter.WriteHeader(code)
}
func (sr *statusRecorder) Write(b []byte) (int, error) {
n, err := sr.ResponseWriter.Write(b)
sr.size += n
return n, err
}
func (sr *statusRecorder) Flush() {
if f, ok := sr.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// clientIP resolves the originating client address, preferring proxy headers
// over the raw connection address.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if first, _, found := strings.Cut(xff, ","); found {
return strings.TrimSpace(first)
}
return strings.TrimSpace(xff)
}
if xr := r.Header.Get("X-Real-IP"); xr != "" {
return strings.TrimSpace(xr)
}
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return host
}
return r.RemoteAddr
}
// CreateRequestLogMiddleware returns middleware that records one access-log
// line per request to proxylog, in the legacy format:
//
// clientIP "METHOD PATH PROTO" status bodySize "UA" duration
//
// Frequently-polled health/metrics paths are skipped. The path is captured
// before next runs because /upstream rewrites the request URL in place.
func CreateRequestLogMiddleware(proxylog *logmon.Monitor) chain.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, prefix := range requestLogPathSkips {
if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r)
return
}
}
start := time.Now()
ip, method, path, proto, ua := clientIP(r), r.Method, r.URL.Path, r.Proto, r.UserAgent()
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(rec, r)
proxylog.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
ip, method, path, proto, rec.status, rec.size, ua, time.Since(start))
})
}
}
+137
View File
@@ -0,0 +1,137 @@
package server
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
func TestServer_NewLoggers(t *testing.T) {
t.Run("proxy mode routes proxy into muxlog, discards upstream", func(t *testing.T) {
mux, proxy, upstream := NewLoggers(config.LogToStdoutProxy)
proxy.Info("PROXYLINE")
upstream.Info("UPSTREAMLINE")
h := string(mux.GetHistory())
if !strings.Contains(h, "PROXYLINE") {
t.Errorf("muxlog missing proxy line: %q", h)
}
if strings.Contains(h, "UPSTREAMLINE") {
t.Errorf("muxlog should not contain upstream line: %q", h)
}
})
t.Run("both mode routes proxy and upstream into muxlog", func(t *testing.T) {
mux, proxy, upstream := NewLoggers(config.LogToStdoutBoth)
proxy.Info("PROXYLINE")
upstream.Info("UPSTREAMLINE")
h := string(mux.GetHistory())
if !strings.Contains(h, "PROXYLINE") || !strings.Contains(h, "UPSTREAMLINE") {
t.Errorf("muxlog history = %q", h)
}
})
t.Run("none mode discards everything from muxlog", func(t *testing.T) {
mux, proxy, upstream := NewLoggers(config.LogToStdoutNone)
proxy.Info("PROXYLINE")
upstream.Info("UPSTREAMLINE")
if len(mux.GetHistory()) != 0 {
t.Errorf("muxlog should be empty, got %q", mux.GetHistory())
}
})
}
func TestServer_HandleLogs_Plain(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.muxlog.Write([]byte("a log line"))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs", nil))
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
if ct := w.Header().Get("Content-Type"); ct != "text/plain" {
t.Errorf("Content-Type = %q, want text/plain", ct)
}
if w.Body.String() != "a log line" {
t.Errorf("body = %q", w.Body.String())
}
}
func TestServer_HandleLogs_HTMLRedirect(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
req.Header.Set("Accept", "text/html")
w := httptest.NewRecorder()
s.ServeHTTP(w, req)
if w.Code != http.StatusFound {
t.Fatalf("status = %d, want 302", w.Code)
}
if got := w.Header().Get("Location"); got != "/ui/" {
t.Errorf("Location = %q, want /ui/", got)
}
}
func TestServer_ClientIP(t *testing.T) {
cases := []struct {
name string
setup func(*http.Request)
want string
}{
{"remote addr", func(r *http.Request) { r.RemoteAddr = "10.0.0.5:1234" }, "10.0.0.5"},
{"x-forwarded-for", func(r *http.Request) {
r.Header.Set("X-Forwarded-For", "1.2.3.4, 5.6.7.8")
}, "1.2.3.4"},
{"x-real-ip", func(r *http.Request) { r.Header.Set("X-Real-IP", "9.9.9.9") }, "9.9.9.9"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.RemoteAddr = ""
c.setup(r)
if got := clientIP(r); got != c.want {
t.Errorf("clientIP() = %q, want %q", got, c.want)
}
})
}
}
func TestServer_RequestLogMiddleware(t *testing.T) {
proxylog := logmon.NewWriter(io.Discard)
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
w.Write([]byte("hello"))
})
mw := CreateRequestLogMiddleware(proxylog)
t.Run("logs request", func(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
r.RemoteAddr = "192.168.1.1:5000"
mw(final).ServeHTTP(httptest.NewRecorder(), r)
line := string(proxylog.GetHistory())
for _, want := range []string{"192.168.1.1", "POST /v1/chat/completions", "201", "5"} {
if !strings.Contains(line, want) {
t.Errorf("log line %q missing %q", line, want)
}
}
})
for _, path := range []string{"/wol-health", "/api/performance", "/metrics"} {
t.Run("skips "+path, func(t *testing.T) {
skipLog := logmon.NewWriter(io.Discard)
skipMW := CreateRequestLogMiddleware(skipLog)
skipMW(final).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, path, nil))
if len(skipLog.GetHistory()) != 0 {
t.Errorf("%s should not be logged; got %q", path, skipLog.GetHistory())
}
})
}
}
+450
View File
@@ -0,0 +1,450 @@
package server
import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/mostlygeek/llama-swap/internal/cache"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/ring"
"github.com/mostlygeek/llama-swap/internal/shared"
"github.com/tidwall/gjson"
)
// TokenMetrics holds token usage and performance metrics.
type TokenMetrics struct {
CachedTokens int `json:"cache_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
PromptPerSecond float64 `json:"prompt_per_second"`
TokensPerSecond float64 `json:"tokens_per_second"`
}
// ActivityLogEntry represents parsed token statistics from llama-server logs.
type ActivityLogEntry struct {
ID int `json:"id"`
Timestamp time.Time `json:"timestamp"`
Model string `json:"model"`
ReqPath string `json:"req_path"`
RespContentType string `json:"resp_content_type"`
RespStatusCode int `json:"resp_status_code"`
Tokens TokenMetrics `json:"tokens"`
DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"`
}
// ActivityLogEvent carries a single activity log entry to event subscribers.
type ActivityLogEvent struct {
Metrics ActivityLogEntry
}
func (e ActivityLogEvent) Type() uint32 {
return shared.ActivityLogEventID
}
// metricsMonitor parses upstream responses for token statistics, keeps a
// bounded in-memory ring of recent activity, and (when captures are enabled)
// stores zstd+CBOR-compressed request/response captures in a sized cache.
type metricsMonitor struct {
mu sync.RWMutex
metrics ring.Buffer[ActivityLogEntry]
nextID int
logger *logmon.Monitor
enableCaptures bool
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
}
// newMetricsMonitor creates a metricsMonitor retaining up to maxMetrics entries.
// captureBufferMB is the capture buffer size in megabytes; 0 disables captures.
func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
if maxMetrics <= 0 {
maxMetrics = 1000
}
mm := &metricsMonitor{
logger: logger,
metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics),
enableCaptures: captureBufferMB > 0,
}
if captureBufferMB > 0 {
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
}
return mm
}
// queueMetrics adds a metric to the ring and returns its assigned ID.
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
mp.mu.Lock()
defer mp.mu.Unlock()
metric.ID = mp.nextID
mp.nextID++
mp.metrics.Push(metric)
return metric.ID
}
// emitMetric publishes an ActivityLogEvent for the given metric.
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
event.Emit(ActivityLogEvent{Metrics: metric})
}
// getMetrics returns a copy of the current metrics.
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
mp.mu.RLock()
defer mp.mu.RUnlock()
result := mp.metrics.Slice()
if result == nil {
return []ActivityLogEntry{}
}
if mp.captureCache != nil {
for i := range result {
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
}
}
return result
}
// getMetricsJSON returns the current metrics as a JSON array.
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
return json.Marshal(mp.getMetrics())
}
// record parses a completed response body and stores/emits an activity entry.
// When captures are enabled, a zstd+CBOR capture is stored for successful
// requests, with cf controlling which request/response parts are retained.
// reqBody and reqHeaders are the request data buffered before dispatch.
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
tm := ActivityLogEntry{
Timestamp: time.Now(),
Model: modelID,
ReqPath: r.URL.Path,
RespContentType: recorder.Header().Get("Content-Type"),
RespStatusCode: recorder.Status(),
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
}
queueAndEmit := func() {
tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm)
}
if recorder.Status() != http.StatusOK {
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
queueAndEmit()
return
}
body := recorder.body.Bytes()
if len(body) == 0 {
mp.logger.Warn("metrics: empty body, recording minimal metrics")
queueAndEmit()
return
}
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
decoded, err := decompressBody(body, encoding)
if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
queueAndEmit()
return
}
body = decoded
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, r.URL.Path)
} else {
tm.Tokens = parsed.Tokens
tm.DurationMs = parsed.DurationMs
}
} else if gjson.ValidBytes(body) {
parsed := gjson.ParseBytes(body)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
// /infill responses are arrays; timings live in the last element (#463).
if strings.HasPrefix(r.URL.Path, "/infill") {
if arr := parsed.Array(); len(arr) > 0 {
timings = arr[len(arr)-1].Get("timings")
}
}
if usage.Exists() || timings.Exists() {
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, r.URL.Path)
} else {
tm.Tokens = parsedMetrics.Tokens
tm.DurationMs = parsedMetrics.DurationMs
}
}
} else {
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", r.URL.Path)
}
tm.ID = mp.queueMetrics(tm)
if mp.enableCaptures {
capture := ReqRespCapture{
ID: tm.ID,
ReqPath: r.URL.Path,
ReqHeaders: reqHeaders,
}
if cf&captureReqBody != 0 {
capture.ReqBody = reqBody
}
if cf&captureRespHeaders != 0 {
capture.RespHeaders = headerMap(recorder.Header())
redactHeaders(capture.RespHeaders)
delete(capture.RespHeaders, "Content-Encoding")
}
if cf&captureRespBody != 0 {
capture.RespBody = body
}
if mp.addCapture(capture) {
tm.HasCapture = true
}
}
mp.emitMetric(tm)
}
// usagePaths lists the JSON paths where a per-event usage object can live.
var usagePaths = []string{"usage", "response.usage", "message.usage"}
// extractUsageTokens reads input/output/cached token counts from a usage
// gjson.Result, handling the field-name differences across endpoints.
func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) {
cached = -1
if !usage.Exists() {
return
}
if v := usage.Get("prompt_tokens"); v.Exists() {
input = v.Int()
ok = true
} else if v := usage.Get("input_tokens"); v.Exists() {
input = v.Int()
ok = true
}
if v := usage.Get("completion_tokens"); v.Exists() {
output = v.Int()
ok = true
} else if v := usage.Get("output_tokens"); v.Exists() {
output = v.Int()
ok = true
}
if v := usage.Get("cache_read_input_tokens"); v.Exists() {
cached = v.Int()
ok = true
} else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() {
cached = v.Int()
ok = true
} else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
cached = v.Int()
ok = true
}
return
}
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
var (
inputTokens, outputTokens int64
cachedTokens int64 = -1
hasAny bool
timings gjson.Result
)
prefix := []byte("data:")
for offset := 0; offset < len(body); {
nl := bytes.IndexByte(body[offset:], '\n')
var line []byte
if nl == -1 {
line = body[offset:]
offset = len(body)
} else {
line = body[offset : offset+nl]
offset += nl + 1
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
continue
}
if !gjson.ValidBytes(data) {
continue
}
parsed := gjson.ParseBytes(data)
for _, path := range usagePaths {
u := parsed.Get(path)
if !u.Exists() {
continue
}
i, o, c, ok := extractUsageTokens(u)
if !ok {
continue
}
hasAny = true
if i > 0 {
inputTokens = i
}
if o > 0 {
outputTokens = o
}
if c >= 0 {
cachedTokens = c
}
}
if t := parsed.Get("timings"); t.Exists() {
timings = t
hasAny = true
}
}
if !hasAny {
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
}
return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil
}
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
input, output, cached, _ := extractUsageTokens(usage)
return buildMetrics(modelID, start, input, output, cached, timings), nil
}
// buildMetrics composes an ActivityLogEntry from accumulated token counts and
// optional llama-server timings (which override input/output and provide rates).
func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry {
wallDurationMs := int(time.Since(start).Milliseconds())
durationMs := wallDurationMs
tokensPerSecond := -1.0
promptPerSecond := -1.0
if timings.Exists() {
inputTokens = timings.Get("prompt_n").Int()
outputTokens = timings.Get("predicted_n").Int()
promptPerSecond = timings.Get("prompt_per_second").Float()
tokensPerSecond = timings.Get("predicted_per_second").Float()
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
if timingsDurationMs > durationMs {
durationMs = timingsDurationMs
}
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
cachedTokens = cachedValue.Int()
}
}
return ActivityLogEntry{
Timestamp: time.Now(),
Model: modelID,
Tokens: TokenMetrics{
CachedTokens: int(cachedTokens),
InputTokens: int(inputTokens),
OutputTokens: int(outputTokens),
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
},
DurationMs: durationMs,
}
}
// decompressBody decompresses the body based on the Content-Encoding header.
func decompressBody(body []byte, encoding string) ([]byte, error) {
switch strings.ToLower(strings.TrimSpace(encoding)) {
case "gzip":
reader, err := gzip.NewReader(bytes.NewReader(body))
if err != nil {
return nil, err
}
defer reader.Close()
return io.ReadAll(reader)
case "deflate":
reader := flate.NewReader(bytes.NewReader(body))
defer reader.Close()
return io.ReadAll(reader)
default:
return body, nil
}
}
// filterAcceptEncoding filters Accept-Encoding to only gzip/deflate so response
// bodies remain decompressible for metrics parsing.
func filterAcceptEncoding(acceptEncoding string) string {
if acceptEncoding == "" {
return ""
}
supported := map[string]bool{"gzip": true, "deflate": true}
var filtered []string
for part := range strings.SplitSeq(acceptEncoding, ",") {
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
if supported[strings.ToLower(encoding)] {
filtered = append(filtered, strings.TrimSpace(part))
}
}
return strings.Join(filtered, ", ")
}
// responseBodyCopier tees the upstream response to the client while buffering
// it for metrics parsing. Status defaults to 200 until WriteHeader is called.
type responseBodyCopier struct {
http.ResponseWriter
body *bytes.Buffer
tee io.Writer
status int
wroteHeader bool
start time.Time
}
func newBodyCopier(w http.ResponseWriter) *responseBodyCopier {
buf := &bytes.Buffer{}
return &responseBodyCopier{
ResponseWriter: w,
body: buf,
tee: io.MultiWriter(w, buf),
status: http.StatusOK,
start: time.Now(),
}
}
func (w *responseBodyCopier) Write(b []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
return w.tee.Write(b)
}
func (w *responseBodyCopier) WriteHeader(statusCode int) {
if w.wroteHeader {
return
}
w.wroteHeader = true
w.status = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
// Flush forwards to the underlying writer so streaming responses still flush.
func (w *responseBodyCopier) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
func (w *responseBodyCopier) Status() int { return w.status }
func (w *responseBodyCopier) StartTime() time.Time { return w.start }
+62
View File
@@ -0,0 +1,62 @@
package server
import (
"bytes"
"io"
"net/http"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/router"
)
// CreateMetricsMiddleware returns middleware that records token metrics for
// model-dispatched POST requests. It resolves the model, tees the response into
// a buffer, and parses token usage once the upstream handler returns.
func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if mm == nil || r.Method != http.MethodPost {
next.ServeHTTP(w, r)
return
}
// Resolve the model now so downstream dispatch hits the context
// fast path; FetchContext restores the request body.
data, err := router.FetchContext(r, cfg)
if err != nil {
router.SendError(w, r, router.ErrNoModelInContext)
return
}
// Buffer the request body/headers for capture before dispatch
// consumes them.
cf := captureFieldsFor(r.URL.Path)
var reqBody []byte
var reqHeaders map[string]string
if mm.enableCaptures {
if cf&captureReqBody != 0 && r.Body != nil {
if buffered, err := io.ReadAll(r.Body); err == nil {
reqBody = buffered
r.Body.Close()
r.Body = io.NopCloser(bytes.NewReader(reqBody))
}
}
if cf&captureReqHeaders != 0 {
reqHeaders = headerMap(r.Header)
redactHeaders(reqHeaders)
}
}
// Restrict Accept-Encoding to encodings we can decompress so the
// buffered response body stays parseable.
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
r.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
}
recorder := newBodyCopier(w)
next.ServeHTTP(recorder, r)
mm.record(data.ModelID, r, recorder, cf, reqBody, reqHeaders)
})
}
}
+74
View File
@@ -0,0 +1,74 @@
package server
import (
"testing"
"time"
"github.com/tidwall/gjson"
)
func TestServer_ParseMetrics_ChatCompletions(t *testing.T) {
body := `{"usage":{"prompt_tokens":12,"completion_tokens":7,"prompt_tokens_details":{"cached_tokens":4}}}`
parsed := gjson.Parse(body)
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings"))
if err != nil {
t.Fatalf("parseMetrics: %v", err)
}
if entry.Tokens.InputTokens != 12 || entry.Tokens.OutputTokens != 7 || entry.Tokens.CachedTokens != 4 {
t.Fatalf("tokens = %+v", entry.Tokens)
}
}
func TestServer_ParseMetrics_Timings(t *testing.T) {
body := `{"timings":{"prompt_n":20,"predicted_n":50,"prompt_per_second":100.0,"predicted_per_second":40.0,"prompt_ms":200,"predicted_ms":1250,"cache_n":8}}`
parsed := gjson.Parse(body)
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings"))
if err != nil {
t.Fatalf("parseMetrics: %v", err)
}
if entry.Tokens.InputTokens != 20 || entry.Tokens.OutputTokens != 50 || entry.Tokens.CachedTokens != 8 {
t.Fatalf("tokens = %+v", entry.Tokens)
}
if entry.Tokens.TokensPerSecond != 40.0 || entry.Tokens.PromptPerSecond != 100.0 {
t.Fatalf("rates = %+v", entry.Tokens)
}
if entry.DurationMs != 1450 {
t.Fatalf("DurationMs = %d, want 1450", entry.DurationMs)
}
}
func TestServer_ProcessStreamingResponse(t *testing.T) {
body := []byte("data: {\"choices\":[{}]}\n\n" +
"data: {\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":33}}\n\n" +
"data: [DONE]\n\n")
entry, err := processStreamingResponse("m", time.Now(), body)
if err != nil {
t.Fatalf("processStreamingResponse: %v", err)
}
if entry.Tokens.InputTokens != 15 || entry.Tokens.OutputTokens != 33 {
t.Fatalf("tokens = %+v", entry.Tokens)
}
}
func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
if _, err := processStreamingResponse("m", time.Now(), []byte("data: [DONE]\n\n")); err == nil {
t.Fatal("expected error for stream with no usage data")
}
}
func TestServer_ParseMetrics_Infill(t *testing.T) {
// /infill responses are arrays; timings live in the last element.
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
parsed := gjson.Parse(body)
timings := parsed.Get("timings")
if arr := parsed.Array(); len(arr) > 0 {
timings = arr[len(arr)-1].Get("timings")
}
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), timings)
if err != nil {
t.Fatalf("parseMetrics: %v", err)
}
if entry.Tokens.InputTokens != 5 || entry.Tokens.OutputTokens != 9 {
t.Fatalf("tokens = %+v", entry.Tokens)
}
}
+290
View File
@@ -0,0 +1,290 @@
package server
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/perf"
"github.com/mostlygeek/llama-swap/internal/router"
)
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
// dispatch. It supersedes router.Server: it builds the local and peer routers
// directly and dispatches between them itself.
type Server struct {
cfg config.Config
muxlog *logmon.Monitor
proxylog *logmon.Monitor
upstreamlog *logmon.Monitor
perf *perf.Monitor
inflight *inflightCounter
metrics *metricsMonitor
build BuildInfo
local router.LocalRouter
peer router.Router
mux *http.ServeMux
handler http.Handler
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
}
// modelPostJSONRoutes are endpoints with a model id in the JSON request body.
var modelPostJSONRoutes = []string{
"/v1/chat/completions",
"/v1/responses",
"/v1/completions",
"/v1/messages",
"/v1/messages/count_tokens",
"/v1/embeddings",
"/reranking",
"/rerank",
"/v1/rerank",
"/v1/reranking",
"/infill",
"/completion",
"/v1/audio/speech",
"/v1/audio/voices",
"/v1/images/generations",
"/sdapi/v1/txt2img",
"/sdapi/v1/img2img",
// versionless routes, the /v/ is stripped before the request is forwarded upstream
// see issue #728
"/v/chat/completions",
"/v/responses",
"/v/completions",
"/v/messages",
"/v/messages/count_tokens",
"/v/embeddings",
"/v/rerank",
"/v/reranking",
}
// modelPostFormRoutes are multipart/form-data endpoints with a model id in the form data
var modelPostFormRoutes = []string{
"/v1/audio/transcriptions",
"/v1/images/edits",
}
// modelGetRoutes are model-dispatched GET endpoints (the model arrives as a
// query parameter).
var modelGetRoutes = []string{
"/v1/audio/voices",
"/sdapi/v1/loras",
}
// BuildInfo carries version metadata surfaced by GET /api/version.
type BuildInfo struct {
Version string
Commit string
Date string
}
func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, upstreamlog *logmon.Monitor, perfMon *perf.Monitor, build BuildInfo) (*Server, error) {
var local router.LocalRouter
var err error
if cfg.Matrix != nil {
local, err = router.NewMatrix(cfg, proxylog, upstreamlog)
if err != nil {
return nil, fmt.Errorf("creating matrix router: %w", err)
}
} else {
local, err = router.NewGroup(cfg, proxylog, upstreamlog)
if err != nil {
return nil, fmt.Errorf("creating group router: %w", err)
}
}
peer, err := router.NewPeer(cfg, proxylog)
if err != nil {
return nil, fmt.Errorf("creating peer router: %w", err)
}
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
s := &Server{
cfg: cfg,
muxlog: muxlog,
proxylog: proxylog,
upstreamlog: upstreamlog,
perf: perfMon,
inflight: &inflightCounter{},
metrics: newMetricsMonitor(proxylog, cfg.MetricsMaxInMemory, cfg.CaptureBuffer),
build: build,
local: local,
peer: peer,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
}
s.routes()
s.startPreload()
return s, nil
}
// localPeerHandler dispatches a model-routed request to the local or peer
// router. The model is resolved once via router.FetchContext.
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
stripVersionPrefix(r)
data, err := router.FetchContext(r, s.cfg)
if err != nil {
router.SendError(w, r, router.ErrNoModelInContext)
return
}
switch {
case s.local.Handles(data.ModelID):
s.proxylog.Debugf("dispatch: using local process for model: %s", data.ModelID)
s.local.ServeHTTP(w, r)
case s.peer.Handles(data.ModelID):
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
s.peer.ServeHTTP(w, r)
default:
router.SendError(w, r, router.ErrNoRouterFound)
}
}
// stripVersionPrefix rewrites versionless /v/... requests to their /... form
// before forwarding upstream (issue #728).
func stripVersionPrefix(r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/v/") {
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/v")
}
}
// routes builds the mux, registers every route, and wraps the mux with the
// global CORS middleware.
func (s *Server) routes() {
authMW := CreateAuthMiddleware(s.cfg)
filterMW := CreateFilterMiddleware(s.cfg)
formFilterMW := CreateFormFilterMiddleware(s.cfg)
// Model-dispatched routes get auth + per-model concurrency limiting + body
// filters + in-flight tracking + token metrics. concurrencyMW rejects with
// 429 before the body filters do any rewrite work. filterMW rewrites JSON
// bodies and formFilterMW rewrites multipart bodies; each is a no-op for the
// other's Content-Type. Both run before the metrics middleware so it buffers
// the rewritten body.
modelChain := chain.New(
authMW,
CreateConcurrencyMiddleware(s.cfg),
filterMW,
formFilterMW,
CreateInflightMiddleware(s.inflight),
CreateMetricsMiddleware(s.metrics, s.cfg),
)
// Custom endpoints only need auth.
apiChain := chain.New(authMW)
mux := http.NewServeMux()
dispatch := http.HandlerFunc(s.localPeerHandler)
for _, path := range modelPostJSONRoutes {
mux.Handle("POST "+path, modelChain.Then(dispatch))
}
for _, path := range modelPostFormRoutes {
mux.Handle("POST "+path, modelChain.Then(dispatch))
}
for _, path := range modelGetRoutes {
mux.Handle("GET "+path, modelChain.Then(dispatch))
}
// llama-swap API + custom endpoints.
mux.Handle("GET /v1/models", apiChain.ThenFunc(s.handleListModels))
mux.Handle("GET /logs", apiChain.ThenFunc(s.handleLogs))
mux.Handle("GET /logs/stream", apiChain.ThenFunc(s.handleLogStream))
mux.Handle("GET /logs/stream/{logMonitorID...}", apiChain.ThenFunc(s.handleLogStream))
mux.HandleFunc("GET /health", handleHealth)
mux.HandleFunc("GET /wol-health", handleHealth)
mux.HandleFunc("GET /{$}", handleRootRedirect)
// Embedded UI.
mux.HandleFunc("GET /ui/", s.handleUI)
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
// Prometheus metrics (no auth, matches the legacy endpoint).
mux.HandleFunc("GET /metrics", s.handleMetrics)
// Operations endpoints.
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
// Upstream passthrough.
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream))
// API group (API-key protected) consumed by the UI.
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
mux.Handle("POST /api/models/unload/{model...}", apiChain.ThenFunc(s.handleAPIUnloadModel))
mux.Handle("GET /api/events", apiChain.ThenFunc(s.handleAPIEvents))
mux.Handle("GET /api/metrics", apiChain.ThenFunc(s.handleAPIMetrics))
mux.Handle("GET /api/performance", apiChain.ThenFunc(s.handleAPIPerformance))
mux.Handle("GET /api/version", apiChain.ThenFunc(s.handleAPIVersion))
mux.Handle("GET /api/captures/{id}", apiChain.ThenFunc(s.handleAPICapture))
s.mux = mux
s.handler = chain.New(CreateRequestLogMiddleware(s.proxylog), CreateCORSMiddleware()).Then(mux)
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.handler.ServeHTTP(w, r)
}
// CloseStreams cancels long-lived response streams (Server-Sent Events) so a
// graceful httpServer.Shutdown can drain without blocking on them. It does not
// tear down routers; call Shutdown for that. Safe to call repeatedly.
func (s *Server) CloseStreams() {
s.shutdownFn()
}
// Shutdown stops the local and peer routers in parallel. It is idempotent;
// repeated calls return nil without re-running shutdown.
//
// Callers must drain inflight HTTP requests (httpServer.Shutdown) before
// calling this, otherwise inflight requests 502 when their processes are torn
// down. Call CloseStreams before httpServer.Shutdown so SSE streams do not
// block the drain.
func (s *Server) Shutdown(timeout time.Duration) error {
if !s.shuttingDown.CompareAndSwap(false, true) {
return nil
}
s.shutdownFn()
var wg sync.WaitGroup
var mu sync.Mutex
var errs []error
for _, rt := range []router.Router{s.local, s.peer} {
if rt == nil {
continue
}
wg.Add(1)
go func(rt router.Router) {
defer wg.Done()
if err := rt.Shutdown(timeout); err != nil {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
}(rt)
}
wg.Wait()
return errors.Join(errs...)
}
+331
View File
@@ -0,0 +1,331 @@
package server
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// stubRouter is a minimal router.LocalRouter for Server dispatch tests.
type stubRouter struct {
models map[string]bool
response string
shutdownCalls atomic.Int32
running map[string]process.ProcessState
unloadCalls atomic.Int32
loggers map[string]*logmon.Monitor
}
func newStubRouter(models []string, response string) *stubRouter {
m := make(map[string]bool, len(models))
for _, id := range models {
m[id] = true
}
return &stubRouter{models: m, response: response}
}
func (s *stubRouter) Handles(model string) bool { return s.models[model] }
func (s *stubRouter) Shutdown(_ time.Duration) error { s.shutdownCalls.Add(1); return nil }
func (s *stubRouter) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(s.response))
}
func (s *stubRouter) RunningModels() map[string]process.ProcessState { return s.running }
func (s *stubRouter) Unload(_ time.Duration, _ ...string) { s.unloadCalls.Add(1) }
func (s *stubRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
if s.loggers != nil {
if lg, ok := s.loggers[modelID]; ok {
return lg, true
}
}
return nil, false
}
// newTestServer wires a Server with stub routers and a built mux.
func newTestServer(local router.LocalRouter, peer router.Router) *Server {
ctx, cancel := context.WithCancel(context.Background())
proxylog := logmon.NewWriter(io.Discard)
s := &Server{
cfg: config.Config{},
muxlog: logmon.NewWriter(io.Discard),
proxylog: proxylog,
upstreamlog: logmon.NewWriter(io.Discard),
inflight: &inflightCounter{},
metrics: newMetricsMonitor(proxylog, 0, 0),
local: local,
peer: peer,
shutdownCtx: ctx,
shutdownFn: cancel,
}
s.routes()
return s
}
func chatRequest(model string) *http.Request {
body := strings.NewReader(`{"model":"` + model + `"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", body)
req.Header.Set("Content-Type", "application/json")
return req
}
func TestServer_New_GroupConfig(t *testing.T) {
discard := logmon.NewWriter(io.Discard)
s, err := New(config.Config{HealthCheckTimeout: 15}, discard, discard, discard, nil, BuildInfo{})
if err != nil {
t.Fatalf("New (group): %v", err)
}
if err := s.Shutdown(time.Second); err != nil {
t.Fatalf("Shutdown: %v", err)
}
}
func TestServer_New_MatrixConfig(t *testing.T) {
discard := logmon.NewWriter(io.Discard)
cfg := config.Config{HealthCheckTimeout: 15, Matrix: &config.MatrixConfig{}}
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
if err != nil {
t.Fatalf("New (matrix): %v", err)
}
if err := s.Shutdown(time.Second); err != nil {
t.Fatalf("Shutdown: %v", err)
}
}
func TestServer_RouteToLocalModel(t *testing.T) {
s := newTestServer(
newStubRouter([]string{"local-model"}, "local response"),
newStubRouter(nil, ""),
)
w := httptest.NewRecorder()
s.ServeHTTP(w, chatRequest("local-model"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if w.Body.String() != "local response" {
t.Errorf("body=%q want %q", w.Body.String(), "local response")
}
}
func TestServer_RouteToPeerModel(t *testing.T) {
s := newTestServer(
newStubRouter(nil, ""),
newStubRouter([]string{"peer-model"}, "peer response"),
)
w := httptest.NewRecorder()
s.ServeHTTP(w, chatRequest("peer-model"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if w.Body.String() != "peer response" {
t.Errorf("body=%q want %q", w.Body.String(), "peer response")
}
}
func TestServer_UnknownModelReturns404(t *testing.T) {
s := newTestServer(
newStubRouter([]string{"local-model"}, ""),
newStubRouter(nil, ""),
)
w := httptest.NewRecorder()
s.ServeHTTP(w, chatRequest("unknown-model"))
if w.Code != http.StatusNotFound {
t.Errorf("status=%d want 404 body=%q", w.Code, w.Body.String())
}
}
func TestServer_UnknownPathReturns404(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/does-not-exist", nil))
if w.Code != http.StatusNotFound {
t.Errorf("status=%d want 404", w.Code)
}
}
func TestServer_Health(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
for _, path := range []string{"/health", "/wol-health"} {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
if w.Code != http.StatusOK || w.Body.String() != "OK" {
t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String())
}
}
}
func TestServer_CORSPreflight(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
req := httptest.NewRequest(http.MethodOptions, "/v1/chat/completions", nil)
w := httptest.NewRecorder()
s.ServeHTTP(w, req)
if w.Code != http.StatusNoContent {
t.Fatalf("status=%d want 204", w.Code)
}
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
t.Errorf("Access-Control-Allow-Origin=%q want *", got)
}
}
func TestServer_Unload(t *testing.T) {
local := newStubRouter([]string{"m1"}, "")
s := newTestServer(local, newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/unload", nil))
if w.Code != http.StatusOK || w.Body.String() != "OK" {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := local.unloadCalls.Load(); got != 1 {
t.Errorf("unloadCalls=%d want 1", got)
}
}
func TestServer_Running(t *testing.T) {
local := newStubRouter([]string{"m1"}, "")
local.running = map[string]process.ProcessState{"m1": process.StateReady}
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{Models: map[string]config.ModelConfig{
"m1": {
Cmd: "llama-server",
Proxy: "http://localhost:9999",
UnloadAfter: 300,
Name: "Model One",
Description: "the first model",
},
}}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/running", nil))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
var resp struct {
Running []runningModel `json:"running"`
}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode: %v body=%q", err, w.Body.String())
}
if len(resp.Running) != 1 {
t.Fatalf("running=%v want 1 entry", resp.Running)
}
want := runningModel{
Model: "m1",
State: "ready",
Cmd: "llama-server",
Proxy: "http://localhost:9999",
TTL: 300,
Name: "Model One",
Description: "the first model",
}
if resp.Running[0] != want {
t.Errorf("got %+v want %+v", resp.Running[0], want)
}
}
func TestServer_Preload(t *testing.T) {
local := newStubRouter([]string{"m1"}, "ok")
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{Hooks: config.HooksConfig{
OnStartup: config.HookOnStartup{Preload: []string{"m1"}},
}}
got := make(chan shared.ModelPreloadedEvent, 1)
cancel := event.On(func(e shared.ModelPreloadedEvent) { got <- e })
defer cancel()
s.startPreload()
select {
case e := <-got:
if e.ModelName != "m1" || !e.Success {
t.Errorf("event=%+v want {ModelName:m1 Success:true}", e)
}
case <-time.After(2 * time.Second):
t.Fatal("preload event not received")
}
}
func TestServer_Shutdown_StopsRoutersAndIsIdempotent(t *testing.T) {
local := newStubRouter([]string{"local-model"}, "")
peer := newStubRouter(nil, "")
s := newTestServer(local, peer)
if err := s.Shutdown(time.Second); err != nil {
t.Fatalf("Shutdown: %v", err)
}
if err := s.Shutdown(time.Second); err != nil {
t.Fatalf("second Shutdown: %v", err)
}
if got := local.shutdownCalls.Load(); got != 1 {
t.Errorf("local shutdownCalls=%d want 1", got)
}
if got := peer.shutdownCalls.Load(); got != 1 {
t.Errorf("peer shutdownCalls=%d want 1", got)
}
}
func TestServer_LogStream_ModelID(t *testing.T) {
buf := logmon.NewWriter(io.Discard)
buf.Write([]byte("hello from model"))
local := newStubRouter([]string{"mymodel"}, "")
local.loggers = map[string]*logmon.Monitor{"mymodel": buf}
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{Models: map[string]config.ModelConfig{"mymodel": {}}}
// Pre-cancel the context so the streaming loop exits immediately after
// flushing history.
ctx, cancel := context.WithCancel(context.Background())
cancel()
req := httptest.NewRequest(http.MethodGet, "/logs/stream/mymodel", nil).WithContext(ctx)
w := httptest.NewRecorder()
s.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := w.Body.String(); got != "hello from model" {
t.Errorf("body=%q want %q", got, "hello from model")
}
}
func TestServer_LogStream_UnknownID_Returns400(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs/stream/no-such-model", nil))
if w.Code != http.StatusBadRequest {
t.Errorf("status=%d want 400", w.Code)
}
}
+111
View File
@@ -0,0 +1,111 @@
package server
import (
"embed"
"io/fs"
"net/http"
"path"
"strings"
)
// uiStaticFS holds the embedded UI build. The build is copied into ui_dist by
// the Makefile's `ui` target; placeholder.txt keeps the embed valid before a
// build has run.
//
//go:embed ui_dist
var uiStaticFS embed.FS
// uiFS is the embedded UI rooted at ui_dist.
var uiFS = func() http.FileSystem {
sub, err := fs.Sub(uiStaticFS, "ui_dist")
if err != nil {
panic(err)
}
return http.FS(sub)
}()
// selectEncoding chooses the best pre-compressed encoding the client accepts.
// It returns the encoding ("br" or "gzip") and the matching file extension.
func selectEncoding(acceptEncoding string) (encoding, ext string) {
if acceptEncoding == "" {
return "", ""
}
for _, part := range strings.Split(acceptEncoding, ",") {
if strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) == "br" {
return "br", ".br"
}
}
for _, part := range strings.Split(acceptEncoding, ",") {
if strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) == "gzip" {
return "gzip", ".gz"
}
}
return "", ""
}
// serveCompressedFile serves name from fsys, preferring a pre-compressed
// sibling (name+".br" / name+".gz") when the client accepts it. It returns an
// error without writing a response when name cannot be served, so callers can
// fall back (e.g. SPA routing).
func serveCompressedFile(fsys http.FileSystem, w http.ResponseWriter, r *http.Request, name string) error {
if encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding")); encoding != "" {
if cf, err := fsys.Open(name + ext); err == nil {
defer cf.Close()
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
w.Header().Set("Content-Encoding", encoding)
w.Header().Add("Vary", "Accept-Encoding")
http.ServeContent(w, r, name, stat.ModTime(), cf)
return nil
}
}
}
file, err := fsys.Open(name)
if err != nil {
return err
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return err
}
if stat.IsDir() {
return fs.ErrNotExist
}
http.ServeContent(w, r, name, stat.ModTime(), file)
return nil
}
// handleUI serves the embedded SPA under /ui/.
func (s *Server) handleUI(w http.ResponseWriter, r *http.Request) {
serveUI(uiFS, w, r)
}
// serveUI serves the SPA from fsys. Real files are served with compression
// support; unknown paths without a file extension fall back to index.html so
// client-side routing works.
func serveUI(fsys http.FileSystem, w http.ResponseWriter, r *http.Request) {
name := strings.TrimPrefix(r.URL.Path, "/ui/")
if name == "" {
name = "index.html"
}
if err := serveCompressedFile(fsys, w, r, name); err != nil {
if strings.Contains(path.Base(name), ".") {
http.NotFound(w, r)
return
}
if err := serveCompressedFile(fsys, w, r, "index.html"); err != nil {
http.NotFound(w, r)
}
}
}
// handleFavicon serves /favicon.ico from the embedded UI build.
func (s *Server) handleFavicon(w http.ResponseWriter, r *http.Request) {
if err := serveCompressedFile(uiFS, w, r, "favicon.ico"); err != nil {
http.NotFound(w, r)
}
}
+1
View File
@@ -0,0 +1 @@
placeholder so //go:embed ui_dist succeeds before the UI is built
+92
View File
@@ -0,0 +1,92 @@
package server
import (
"net/http"
"net/http/httptest"
"testing"
"testing/fstest"
)
func TestServer_SelectEncoding(t *testing.T) {
cases := []struct {
accept string
encoding string
ext string
}{
{"", "", ""},
{"gzip", "gzip", ".gz"},
{"gzip, deflate, br", "br", ".br"},
{"deflate", "", ""},
{"br;q=1.0, gzip;q=0.8", "br", ".br"},
}
for _, c := range cases {
enc, ext := selectEncoding(c.accept)
if enc != c.encoding || ext != c.ext {
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)", c.accept, enc, ext, c.encoding, c.ext)
}
}
}
func uiTestFS() http.FileSystem {
return http.FS(fstest.MapFS{
"index.html": {Data: []byte("<html>app</html>")},
"app.js": {Data: []byte("plain")},
"app.js.br": {Data: []byte("brotli")},
"app.js.gz": {Data: []byte("gzipped")},
"favicon.ico": {Data: []byte("icon")},
})
}
func serveUIRequest(t *testing.T, path, acceptEncoding string) *httptest.ResponseRecorder {
t.Helper()
req := httptest.NewRequest(http.MethodGet, path, nil)
if acceptEncoding != "" {
req.Header.Set("Accept-Encoding", acceptEncoding)
}
w := httptest.NewRecorder()
serveUI(uiTestFS(), w, req)
return w
}
func TestServer_ServeUI_File(t *testing.T) {
w := serveUIRequest(t, "/ui/app.js", "")
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", w.Code)
}
if w.Body.String() != "plain" {
t.Errorf("body = %q, want plain", w.Body.String())
}
}
func TestServer_ServeUI_Brotli(t *testing.T) {
w := serveUIRequest(t, "/ui/app.js", "gzip, br")
if got := w.Header().Get("Content-Encoding"); got != "br" {
t.Fatalf("Content-Encoding = %q, want br", got)
}
if w.Body.String() != "brotli" {
t.Errorf("body = %q, want brotli", w.Body.String())
}
}
func TestServer_ServeUI_IndexAndRoot(t *testing.T) {
for _, path := range []string{"/ui/", "/ui/index.html"} {
w := serveUIRequest(t, path, "")
if w.Code != http.StatusOK || w.Body.String() != "<html>app</html>" {
t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String())
}
}
}
func TestServer_ServeUI_SPAFallback(t *testing.T) {
w := serveUIRequest(t, "/ui/models", "")
if w.Code != http.StatusOK || w.Body.String() != "<html>app</html>" {
t.Errorf("SPA fallback: status=%d body=%q", w.Code, w.Body.String())
}
}
func TestServer_ServeUI_MissingFile(t *testing.T) {
w := serveUIRequest(t, "/ui/missing.js", "")
if w.Code != http.StatusNotFound {
t.Errorf("status = %d, want 404", w.Code)
}
}