Introduce new routing backend (#790)
This is a huge backend change that essentially started with rewriting the concurrency handling for processes and blew up to a refactor of the entire application. In short these are the improvements: **Better state and life cycle management:** Life cycle management of processes has always been the trickiest part of the code. Juggling mutex locks between multiple locations to reduce race conditions was complex. Too complex for my feeble brain to build a simple mental model around as llama-swap gained more features. All of that has been refactored. Most of the locks are gone, replaced with a single run() that owns all state changes. There is one place to start from now to understand and extend routing logic. The improved life cycle management makes it easier to implement more complex swap optimization strategies in the future like #727. **Collation of requests:** llama-swap previously handled requests and swapping in the order they came in. For example requests for models in this order ABCABC would result in 5 swaps. Now those requests are handled in this order AABBCC. The result is less time waiting for swap under a high churn request queue. This fixes #588 #612. A possible future enhancement is to support a starvation parameter so swap can be forced when models have been waiting too long. **Shared base implementation for groups and swap matrix:** During the refactor it became clear that much of the swapping logic was shared between these two implementations. That is not surprising considering the swap matrix was added many moons after groups. Now they share a common base and their specific swap strategies are implemented into the swapPlanner interface. Requests for bespoke or specific swapping scenarios is a common theme in the issues. Now users can implement whatever bespoke and weird swapping strategy they want in their own fork. Just ask your agent of choice to implement swapPlanner. I'll still remaining more conservative on what actually lands in core llama-swap and will continue to evaluate PRs if the changes is good for everyone or just one specific use case. **AI / Agentic Disclosure:** I paid very close attention to the low level swap concurrency design and implementation. It's important to keep that essential part reliable, boring and no surprises. Backwards compatibility was also maintained, even the one way non-exclusive group model loading behaviour that people have rightly pointed out be a weird design decision. With the underlying swap core done the web server, api and UI sitting on top were largely ported over with Claude Code and Opus 4.7 in multiple phases. If you're curious I kept the changes in docs/newrouter-todo.md. I did several passes to make sure things weren't left behind. However, even frontier LLMs at the time of this PR still make small decisions that don't make a lot of sense. They get shit wrong all the time, just in small subtle way. That said, there's likely to be some new bugs introduced with this massive refactor. I'm fairly confident that there's no major architectural flaws that would cause goal seeking agents to make dumb, ugly code decisions. For a little while the legacy llama-swap will be available under cmd/legacy/llama-swap. The plan is to eventually delete that entry point as well as the proxy package. On a bit of a personal note, this PR is exciting and a bit sad for me. I hand wrote much of the original code and this PR ultimately replaces much of it. While the old code served as a good reference for the agent to implement the new stuff it still a bit sad to eventually delete it all.
This commit is contained in:
@@ -0,0 +1,266 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
|
||||
type modelRecord struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// handleListModels serves the OpenAI-compatible model listing: local models
|
||||
// (with optional aliases) plus peer models.
|
||||
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
created := time.Now().Unix()
|
||||
data := make([]modelRecord, 0, len(s.cfg.Models))
|
||||
|
||||
newRecord := func(id, name, description string, metadata map[string]any) modelRecord {
|
||||
rec := modelRecord{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
Created: created,
|
||||
OwnedBy: "llama-swap",
|
||||
Name: strings.TrimSpace(name),
|
||||
Description: strings.TrimSpace(description),
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
rec.Meta = map[string]any{"llamaswap": metadata}
|
||||
}
|
||||
return rec
|
||||
}
|
||||
|
||||
for id, mc := range s.cfg.Models {
|
||||
if mc.Unlisted {
|
||||
continue
|
||||
}
|
||||
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata))
|
||||
|
||||
if s.cfg.IncludeAliasesInList {
|
||||
for _, alias := range mc.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for peerID, peer := range s.cfg.Peers {
|
||||
for _, modelID := range peer.Models {
|
||||
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}))
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(data, func(i, j int) bool { return data[i].ID < data[j].ID })
|
||||
|
||||
// Echo the Origin so browser clients can read the listing.
|
||||
if origin := r.Header.Get("Origin"); origin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// runningModel is one entry in the /running listing.
|
||||
type runningModel struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Cmd string `json:"cmd"`
|
||||
Proxy string `json:"proxy"`
|
||||
TTL int `json:"ttl"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// handleUnload stops every running local process. Peer models are remote and
|
||||
// unaffected.
|
||||
func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(0)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
// handleRunning lists local processes that are not stopped, joining each model
|
||||
// ID against its config for the cmd/proxy/ttl/name/description metadata.
|
||||
func (s *Server) handleRunning(w http.ResponseWriter, r *http.Request) {
|
||||
states := s.local.RunningModels()
|
||||
list := make([]runningModel, 0, len(states))
|
||||
for id, state := range states {
|
||||
mc := s.cfg.Models[id]
|
||||
list = append(list, runningModel{
|
||||
Model: id,
|
||||
State: string(state),
|
||||
Cmd: mc.Cmd,
|
||||
Proxy: mc.Proxy,
|
||||
TTL: mc.UnloadAfter,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
})
|
||||
}
|
||||
sort.Slice(list, func(i, j int) bool { return list[i].Model < list[j].Model })
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{"running": list})
|
||||
}
|
||||
|
||||
// discardResponseWriter satisfies http.ResponseWriter for preload requests,
|
||||
// dropping the body while capturing the status code.
|
||||
type discardResponseWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (d *discardResponseWriter) Header() http.Header {
|
||||
if d.header == nil {
|
||||
d.header = make(http.Header)
|
||||
}
|
||||
return d.header
|
||||
}
|
||||
|
||||
func (d *discardResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
|
||||
|
||||
func (d *discardResponseWriter) WriteHeader(status int) { d.status = status }
|
||||
|
||||
// startPreload fires a background GET / at every model named in
|
||||
// Hooks.OnStartup.Preload so they are warm before the first real request.
|
||||
// Preload names are already resolved to real model IDs by config loading.
|
||||
func (s *Server) startPreload() {
|
||||
models := s.cfg.Hooks.OnStartup.Preload
|
||||
if len(models) == 0 {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for _, modelID := range models {
|
||||
if !s.local.Handles(modelID) {
|
||||
s.proxylog.Warnf("preload: model %s is not a local model, skipping", modelID)
|
||||
continue
|
||||
}
|
||||
s.proxylog.Infof("preloading model: %s", modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(s.shutdownCtx, http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
|
||||
|
||||
dw := &discardResponseWriter{status: http.StatusOK}
|
||||
s.local.ServeHTTP(dw, req)
|
||||
|
||||
success := dw.status < http.StatusBadRequest
|
||||
if !success {
|
||||
s.proxylog.Errorf("failed to preload model %s: status %d", modelID, dw.status)
|
||||
}
|
||||
event.Emit(shared.ModelPreloadedEvent{ModelName: modelID, Success: success})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleMetrics serves Prometheus-format performance metrics. Returns 503 when
|
||||
// performance monitoring is disabled.
|
||||
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if s.perf == nil {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte("# performance monitor not available\n"))
|
||||
return
|
||||
}
|
||||
s.perf.MetricsHandler().ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func handleRootRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ui", http.StatusFound)
|
||||
}
|
||||
|
||||
func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ui/models", http.StatusFound)
|
||||
}
|
||||
|
||||
// handleUpstream proxies ANY request under /upstream/<model>/<path> directly to
|
||||
// the model's process, bypassing model dispatch by body/query inspection.
|
||||
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamPath := r.PathValue("upstreamPath")
|
||||
|
||||
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
|
||||
if !found {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect /upstream/model to /upstream/model/ so relative URLs in upstream
|
||||
// responses resolve. 301 for GET/HEAD, 308 otherwise to preserve the method.
|
||||
if remainingPath == "/" && !strings.HasSuffix(r.URL.Path, "/") {
|
||||
newPath := "/upstream/" + searchName + "/"
|
||||
if r.URL.RawQuery != "" {
|
||||
newPath += "?" + r.URL.RawQuery
|
||||
}
|
||||
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
||||
http.Redirect(w, r, newPath, http.StatusMovedPermanently)
|
||||
} else {
|
||||
http.Redirect(w, r, newPath, http.StatusPermanentRedirect)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Strip the /upstream/<model> prefix before forwarding.
|
||||
r.URL.Path = remainingPath
|
||||
// Pin the resolved model so the router skips body/query extraction.
|
||||
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
|
||||
|
||||
switch {
|
||||
case s.local.Handles(modelID):
|
||||
s.local.ServeHTTP(w, r)
|
||||
case s.peer.Handles(modelID):
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// findModelInPath walks a slash-separated path, building up segments until one
|
||||
// matches a configured model. This resolves model names that contain slashes
|
||||
// (e.g. "author/model"). Returns the matched name, its real model ID, the
|
||||
// remaining path, and whether a match was found.
|
||||
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
|
||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||
name := ""
|
||||
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if name == "" {
|
||||
name = part
|
||||
} else {
|
||||
name = name + "/" + part
|
||||
}
|
||||
|
||||
if modelID, ok := cfg.RealModelName(name); ok {
|
||||
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", "", false
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
func TestServer_HandleListModels(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"visible": {Name: "Visible", Description: "a model"},
|
||||
"hidden": {Unlisted: true},
|
||||
},
|
||||
Peers: config.PeerDictionaryConfig{
|
||||
"peer1": {Models: []string{"remote-model"}},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Errorf("Access-Control-Allow-Origin = %q", got)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Data []modelRecord `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
ids := map[string]bool{}
|
||||
for _, m := range resp.Data {
|
||||
ids[m.ID] = true
|
||||
}
|
||||
if !ids["visible"] || !ids["remote-model"] {
|
||||
t.Errorf("missing expected models: %v", ids)
|
||||
}
|
||||
if ids["hidden"] {
|
||||
t.Error("unlisted model should not appear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleListModels_Aliases(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
IncludeAliasesInList: true,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"real": {Aliases: []string{"nick"}},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
|
||||
|
||||
var resp struct {
|
||||
Data []modelRecord `json:"data"`
|
||||
}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
ids := map[string]bool{}
|
||||
for _, m := range resp.Data {
|
||||
ids[m.ID] = true
|
||||
}
|
||||
if !ids["real"] || !ids["nick"] {
|
||||
t.Errorf("expected alias entry; got %v", ids)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_FindModelInPath(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||
"author/model": {},
|
||||
"simple": {},
|
||||
}}
|
||||
|
||||
cases := []struct {
|
||||
path string
|
||||
wantName string
|
||||
wantRem string
|
||||
wantFound bool
|
||||
}{
|
||||
{"/simple/v1/chat", "simple", "/v1/chat", true},
|
||||
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
|
||||
{"/author/model", "author/model", "/", true},
|
||||
{"/missing/v1", "", "", false},
|
||||
{"/", "", "", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
name, _, rem, found := findModelInPath(cfg, c.path)
|
||||
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
|
||||
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
|
||||
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleUpstream(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||
|
||||
t.Run("proxies to local", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat", nil))
|
||||
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||
t.Errorf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("redirects bare model path", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1", nil))
|
||||
if w.Code != http.StatusMovedPermanently {
|
||||
t.Errorf("status = %d, want 301", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown model 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/nope/v1", nil))
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/metrics", nil))
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Redirects(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
for path, want := range map[string]string{"/": "/ui", "/upstream": "/ui/models"} {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("%s: status = %d, want 302", path, w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Location"); got != want {
|
||||
t.Errorf("%s: Location = %q, want %q", path, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,270 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// apiModel is one entry in the /api/events modelStatus payload.
|
||||
type apiModel struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
// modelStatus returns every configured model joined with its current process
|
||||
// state (defaulting to "stopped"), followed by peer models.
|
||||
func (s *Server) modelStatus() []apiModel {
|
||||
running := s.local.RunningModels()
|
||||
|
||||
ids := make([]string, 0, len(s.cfg.Models))
|
||||
for id := range s.cfg.Models {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
sort.Strings(ids)
|
||||
|
||||
models := make([]apiModel, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
mc := s.cfg.Models[id]
|
||||
state := "stopped"
|
||||
if st, ok := running[id]; ok {
|
||||
state = string(st)
|
||||
}
|
||||
models = append(models, apiModel{
|
||||
Id: id,
|
||||
Name: mc.Name,
|
||||
Description: mc.Description,
|
||||
State: state,
|
||||
Unlisted: mc.Unlisted,
|
||||
Aliases: mc.Aliases,
|
||||
})
|
||||
}
|
||||
|
||||
for peerID, peer := range s.cfg.Peers {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, apiModel{Id: modelID, PeerID: peerID})
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// handleAPIUnloadAll stops every running local process.
|
||||
func (s *Server) handleAPIUnloadAll(w http.ResponseWriter, r *http.Request) {
|
||||
s.local.Unload(0)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"msg": "ok"})
|
||||
}
|
||||
|
||||
// handleAPIUnloadModel stops a single named local process.
|
||||
func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
|
||||
requested := strings.TrimPrefix(r.PathValue("model"), "/")
|
||||
realName, found := s.cfg.RealModelName(requested)
|
||||
if !found {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "model not found")
|
||||
return
|
||||
}
|
||||
if !s.local.Handles(realName) {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
|
||||
return
|
||||
}
|
||||
s.local.Unload(0, realName)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
// handleAPIMetrics serves the activity log as a JSON array.
|
||||
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := s.metrics.getMetricsJSON()
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// handleAPIPerformance serves the buffered system/GPU stats, optionally
|
||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
if s.perf == nil {
|
||||
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
||||
return
|
||||
}
|
||||
|
||||
sysStats, gpuStats := s.perf.Current()
|
||||
|
||||
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
|
||||
after, err := time.Parse(time.RFC3339, afterStr)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
|
||||
return
|
||||
}
|
||||
filteredSys := make([]perf.SysStat, 0, len(sysStats))
|
||||
for _, st := range sysStats {
|
||||
if st.Timestamp.After(after) {
|
||||
filteredSys = append(filteredSys, st)
|
||||
}
|
||||
}
|
||||
sysStats = filteredSys
|
||||
|
||||
filteredGpu := make([]perf.GpuStat, 0, len(gpuStats))
|
||||
for _, g := range gpuStats {
|
||||
if g.Timestamp.After(after) {
|
||||
filteredGpu = append(filteredGpu, g)
|
||||
}
|
||||
}
|
||||
gpuStats = filteredGpu
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"sys_stats": sysStats,
|
||||
"gpu_stats": gpuStats,
|
||||
})
|
||||
}
|
||||
|
||||
// handleAPIVersion serves the build metadata.
|
||||
func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"version": s.build.Version,
|
||||
"commit": s.build.Commit,
|
||||
"build_date": s.build.Date,
|
||||
})
|
||||
}
|
||||
|
||||
// handleAPICapture returns the stored request/response capture for a metric ID.
|
||||
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.Atoi(r.PathValue("id"))
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
|
||||
return
|
||||
}
|
||||
|
||||
capture := s.metrics.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
|
||||
return
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(capture)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(jsonBytes)
|
||||
}
|
||||
|
||||
type messageType string
|
||||
|
||||
const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
msgTypeInFlight messageType = "inflight"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
Type messageType `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// handleAPIEvents streams server events (model status, log data, metrics,
|
||||
// in-flight counts) to the client as Server-Sent Events.
|
||||
func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering SSE
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
// internal/event already has a 50K event buffer
|
||||
// a 1K message buffer should be enough, watch the logs for the warning that the sendBuffer is full
|
||||
sendBuffer := make(chan messageEnvelope, 1024)
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
|
||||
send := func(msg messageEnvelope) {
|
||||
select {
|
||||
case sendBuffer <- msg:
|
||||
case <-ctx.Done():
|
||||
s.proxylog.Warn("handleAPIEvents send suppressed due to context done")
|
||||
default:
|
||||
s.proxylog.Warn("handleAPIEvents sendBuffer full, dropped message")
|
||||
}
|
||||
}
|
||||
sendModels := func() {
|
||||
if data, err := json.Marshal(s.modelStatus()); err == nil {
|
||||
send(messageEnvelope{Type: msgTypeModelStatus, Data: string(data)})
|
||||
}
|
||||
}
|
||||
sendLogData := func(source string, data []byte) {
|
||||
if j, err := json.Marshal(map[string]string{"source": source, "data": string(data)}); err == nil {
|
||||
send(messageEnvelope{Type: msgTypeLogData, Data: string(j)})
|
||||
}
|
||||
}
|
||||
sendMetrics := func(metrics []ActivityLogEntry) {
|
||||
if j, err := json.Marshal(metrics); err == nil {
|
||||
send(messageEnvelope{Type: msgTypeMetrics, Data: string(j)})
|
||||
}
|
||||
}
|
||||
sendInFlight := func(total int) {
|
||||
if j, err := json.Marshal(map[string]int{"total": total}); err == nil {
|
||||
send(messageEnvelope{Type: msgTypeInFlight, Data: string(j)})
|
||||
}
|
||||
}
|
||||
|
||||
defer event.On(func(e shared.ProcessStateChangeEvent) { sendModels() })()
|
||||
defer event.On(func(e shared.ConfigFileChangedEvent) { sendModels() })()
|
||||
defer s.proxylog.OnLogData(func(data []byte) { sendLogData("proxy", data) })()
|
||||
defer s.upstreamlog.OnLogData(func(data []byte) { sendLogData("upstream", data) })()
|
||||
defer event.On(func(e ActivityLogEvent) { sendMetrics([]ActivityLogEntry{e.Metrics}) })()
|
||||
defer event.On(func(e shared.InFlightRequestsEvent) { sendInFlight(e.Total) })()
|
||||
|
||||
// initial payload
|
||||
sendLogData("proxy", s.proxylog.GetHistory())
|
||||
sendLogData("upstream", s.upstreamlog.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(s.metrics.getMetrics())
|
||||
sendInFlight(int(s.inflight.Current()))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-s.shutdownCtx.Done():
|
||||
return
|
||||
case msg := <-sendBuffer:
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(w, "event:message\ndata:%s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServer_InflightMiddleware(t *testing.T) {
|
||||
c := &inflightCounter{}
|
||||
mw := CreateInflightMiddleware(c)
|
||||
|
||||
var duringRequest int64
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
duringRequest = c.Current()
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil))
|
||||
|
||||
if duringRequest != 1 {
|
||||
t.Errorf("counter during request = %d, want 1", duringRequest)
|
||||
}
|
||||
if got := c.Current(); got != 0 {
|
||||
t.Errorf("counter after request = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_APIVersion(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.build = BuildInfo{Version: "1.2.3", Commit: "deadbeef", Date: "2026-05-19"}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/version", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
var got map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got["version"] != "1.2.3" || got["commit"] != "deadbeef" || got["build_date"] != "2026-05-19" {
|
||||
t.Errorf("body = %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_APIMetrics_Empty(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/metrics", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if body := strings.TrimSpace(w.Body.String()); body != "[]" {
|
||||
t.Errorf("body = %q, want []", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_APIPerformance_Unavailable(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/performance", nil))
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_APIEvents_InitialPayload(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/events", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.ServeHTTP(w, req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("handler did not return after context cancel")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
for _, want := range []string{`"type":"modelStatus"`, `"type":"inflight"`, `"type":"logData"`} {
|
||||
if !strings.Contains(body, want) {
|
||||
t.Errorf("initial SSE payload missing %s; body=%q", want, body)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// CreateAuthMiddleware returns middleware that validates API keys when the
|
||||
// config declares any. It accepts the key via Authorization: Bearer,
|
||||
// Authorization: Basic (password field), or x-api-key. On success the auth
|
||||
// headers are stripped so they never leak to upstream. When no keys are
|
||||
// configured the middleware is a pass-through.
|
||||
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
|
||||
keys := cfg.RequiredAPIKeys
|
||||
return func(next http.Handler) http.Handler {
|
||||
if len(keys) == 0 {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
provided := extractAPIKey(r)
|
||||
|
||||
valid := false
|
||||
for _, key := range keys {
|
||||
if provided == key {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
return
|
||||
}
|
||||
|
||||
r.Header.Del("Authorization")
|
||||
r.Header.Del("x-api-key")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
|
||||
// then Bearer, then x-api-key.
|
||||
func extractAPIKey(r *http.Request) string {
|
||||
var bearerKey, basicKey string
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||
} else if strings.HasPrefix(auth, "Basic ") {
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
|
||||
basicKey = parts[1] // password field is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case basicKey != "":
|
||||
return basicKey
|
||||
case bearerKey != "":
|
||||
return bearerKey
|
||||
default:
|
||||
return r.Header.Get("x-api-key")
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCORSMiddleware returns middleware that answers OPTIONS preflight
|
||||
// requests with permissive CORS headers (see issues #81, #77, #42). Non-OPTIONS
|
||||
// requests pass through untouched.
|
||||
func CreateCORSMiddleware() chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodOptions {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", sanitizeAccessControlRequestHeaderValues(headers))
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept, X-Requested-With")
|
||||
}
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func isTokenChar(r rune) bool {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// sanitizeAccessControlRequestHeaderValues drops any header names that contain
|
||||
// characters outside the HTTP token grammar before echoing them back.
|
||||
func sanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||
parts := strings.Split(headerValues, ",")
|
||||
valid := make([]string, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
validPart := true
|
||||
for _, c := range v {
|
||||
if !isTokenChar(c) {
|
||||
validPart = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if validPart {
|
||||
valid = append(valid, v)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(valid, ", ")
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
)
|
||||
|
||||
func TestServer_ExtractAPIKey(t *testing.T) {
|
||||
basicHeader := func(user, pass string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
auth string
|
||||
xapi string
|
||||
want string
|
||||
}{
|
||||
{"none", "", "", ""},
|
||||
{"bearer", "Bearer tok123", "", "tok123"},
|
||||
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
|
||||
{"x-api-key", "", "xkey", "xkey"},
|
||||
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
|
||||
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
|
||||
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if c.auth != "" {
|
||||
r.Header.Set("Authorization", c.auth)
|
||||
}
|
||||
if c.xapi != "" {
|
||||
r.Header.Set("x-api-key", c.xapi)
|
||||
}
|
||||
if got := extractAPIKey(r); got != c.want {
|
||||
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"Content-Type, Authorization", "Content-Type, Authorization"},
|
||||
{" X-Custom , Accept ", "X-Custom, Accept"},
|
||||
{"Valid, Bad Header", "Valid"},
|
||||
{"Bad@Header", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := sanitizeAccessControlRequestHeaderValues(c.in); got != c.want {
|
||||
t.Errorf("sanitize(%q) = %q, want %q", c.in, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_IsTokenChar(t *testing.T) {
|
||||
for _, r := range "abcXYZ0129!#$%&'*+-.^_`|~" {
|
||||
if !isTokenChar(r) {
|
||||
t.Errorf("isTokenChar(%q) = false, want true", r)
|
||||
}
|
||||
}
|
||||
for _, r := range " @()/\t\"" {
|
||||
if isTokenChar(r) {
|
||||
t.Errorf("isTokenChar(%q) = true, want false", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_AuthMiddleware(t *testing.T) {
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
|
||||
t.Error("auth headers leaked to upstream")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("no keys configured passes through", func(t *testing.T) {
|
||||
mw := CreateAuthMiddleware(config.Config{})
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
cfg := config.Config{RequiredAPIKeys: []string{"secret"}}
|
||||
|
||||
t.Run("valid key", func(t *testing.T) {
|
||||
mw := CreateAuthMiddleware(cfg)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Authorization", "Bearer secret")
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid key", func(t *testing.T) {
|
||||
mw := CreateAuthMiddleware(cfg)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Authorization", "Bearer wrong")
|
||||
w := httptest.NewRecorder()
|
||||
mw(final).ServeHTTP(w, r)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
if w.Header().Get("WWW-Authenticate") == "" {
|
||||
t.Error("missing WWW-Authenticate header")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// ReqRespCapture is a stored request/response pair for a single metered request.
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// captureFields is a bitmask controlling what a route stores in a ReqRespCapture.
|
||||
type captureFields uint
|
||||
|
||||
const (
|
||||
captureReqHeaders captureFields = 1 << iota
|
||||
captureReqBody
|
||||
captureRespHeaders
|
||||
captureRespBody
|
||||
)
|
||||
|
||||
const (
|
||||
captureReqAll = captureReqHeaders | captureReqBody
|
||||
captureRespAll = captureRespHeaders | captureRespBody
|
||||
captureAll = captureReqAll | captureRespAll
|
||||
)
|
||||
|
||||
// captureFieldsByPath overrides the default capture mask for routes carrying
|
||||
// large binary payloads (audio/image) where storing the full body is wasteful.
|
||||
var captureFieldsByPath = map[string]captureFields{
|
||||
"/v1/audio/speech": captureReqAll | captureRespHeaders,
|
||||
"/v1/audio/voices": captureReqHeaders | captureRespAll,
|
||||
"/v1/audio/transcriptions": captureReqHeaders | captureRespHeaders | captureRespBody,
|
||||
"/v1/images/generations": captureReqAll | captureRespHeaders,
|
||||
"/v1/images/edits": captureReqHeaders | captureRespHeaders,
|
||||
"/sdapi/v1/txt2img": captureReqAll | captureRespHeaders,
|
||||
"/sdapi/v1/img2img": captureReqHeaders | captureRespHeaders,
|
||||
}
|
||||
|
||||
// captureFieldsFor returns the capture mask for a request path. Unlisted routes
|
||||
// (the OpenAI-compatible JSON endpoints) capture everything.
|
||||
func captureFieldsFor(path string) captureFields {
|
||||
if cf, ok := captureFieldsByPath[path]; ok {
|
||||
return cf
|
||||
}
|
||||
return captureAll
|
||||
}
|
||||
|
||||
// zstdEncOptions are the shared zstd encoder options for maximum compression.
|
||||
var zstdEncOptions = []zstd.EOption{
|
||||
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||
}
|
||||
|
||||
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||
var zstdEncPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||
return enc
|
||||
},
|
||||
}
|
||||
|
||||
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||
var zstdDecPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
dec, _ := zstd.NewReader(nil)
|
||||
return dec
|
||||
},
|
||||
}
|
||||
|
||||
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
|
||||
// Returns the compressed bytes and the original CBOR byte count for logging.
|
||||
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||
cborBytes, err := cbor.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||
}
|
||||
zenc := zstdEncPool.Get().(*zstd.Encoder)
|
||||
defer zstdEncPool.Put(zenc)
|
||||
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
|
||||
}
|
||||
|
||||
// decompressCapture decompresses zstd-compressed CBOR into a ReqRespCapture.
|
||||
func decompressCapture(data []byte) (*ReqRespCapture, error) {
|
||||
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||
defer zstdDecPool.Put(dec)
|
||||
cborBytes, err := dec.DecodeAll(data, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress capture: %w", err)
|
||||
}
|
||||
var capture ReqRespCapture
|
||||
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal capture: %w", err)
|
||||
}
|
||||
return &capture, nil
|
||||
}
|
||||
|
||||
// addCapture compresses and stores a capture in the cache. Returns true if the
|
||||
// capture was stored.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
|
||||
if !mp.enableCaptures {
|
||||
return false
|
||||
}
|
||||
|
||||
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
|
||||
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
|
||||
return false
|
||||
}
|
||||
|
||||
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
|
||||
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||
return true
|
||||
}
|
||||
|
||||
// getCaptureByID decompresses and unmarshals a capture by ID. Returns nil if
|
||||
// the capture is not found or decompression fails.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
if mp.captureCache == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := mp.captureCache.Get(id)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
capture, err := decompressCapture(data)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
return capture
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that are redacted in captures.
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// headerMap flattens an http.Header to a single-value map.
|
||||
func headerMap(h http.Header) map[string]string {
|
||||
m := make(map[string]string, len(h))
|
||||
for key, values := range h {
|
||||
if len(values) > 0 {
|
||||
m[key] = values[0]
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]".
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func TestServer_CaptureCompressRoundtrip(t *testing.T) {
|
||||
orig := &ReqRespCapture{
|
||||
ID: 7,
|
||||
ReqPath: "/v1/chat/completions",
|
||||
ReqHeaders: map[string]string{"Content-Type": "application/json"},
|
||||
ReqBody: []byte(`{"model":"m"}`),
|
||||
RespHeaders: map[string]string{"Content-Type": "application/json"},
|
||||
RespBody: []byte(`{"usage":{}}`),
|
||||
}
|
||||
|
||||
compressed, uncompressed, err := compressCapture(orig)
|
||||
if err != nil {
|
||||
t.Fatalf("compressCapture: %v", err)
|
||||
}
|
||||
if uncompressed == 0 || len(compressed) == 0 {
|
||||
t.Fatalf("unexpected sizes: uncompressed=%d compressed=%d", uncompressed, len(compressed))
|
||||
}
|
||||
|
||||
got, err := decompressCapture(compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("decompressCapture: %v", err)
|
||||
}
|
||||
if got.ID != orig.ID || got.ReqPath != orig.ReqPath ||
|
||||
!bytes.Equal(got.ReqBody, orig.ReqBody) || !bytes.Equal(got.RespBody, orig.RespBody) {
|
||||
t.Fatalf("roundtrip mismatch: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CaptureStoreAndRetrieve(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||
if !mm.enableCaptures {
|
||||
t.Fatal("captures should be enabled with non-zero buffer")
|
||||
}
|
||||
|
||||
capture := ReqRespCapture{ID: 3, ReqPath: "/v1/chat/completions", ReqBody: []byte("hello")}
|
||||
if !mm.addCapture(capture) {
|
||||
t.Fatal("addCapture returned false")
|
||||
}
|
||||
|
||||
got := mm.getCaptureByID(3)
|
||||
if got == nil || !bytes.Equal(got.ReqBody, []byte("hello")) {
|
||||
t.Fatalf("getCaptureByID = %+v", got)
|
||||
}
|
||||
if mm.getCaptureByID(999) != nil {
|
||||
t.Fatal("expected nil for unknown capture ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CaptureDisabled(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 0)
|
||||
if mm.enableCaptures {
|
||||
t.Fatal("captures should be disabled with zero buffer")
|
||||
}
|
||||
if mm.addCapture(ReqRespCapture{ID: 1}) {
|
||||
t.Fatal("addCapture should return false when disabled")
|
||||
}
|
||||
if mm.getCaptureByID(1) != nil {
|
||||
t.Fatal("getCaptureByID should return nil when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CaptureFieldsFor(t *testing.T) {
|
||||
if got := captureFieldsFor("/v1/chat/completions"); got != captureAll {
|
||||
t.Fatalf("default = %b, want captureAll", got)
|
||||
}
|
||||
if got := captureFieldsFor("/v1/audio/speech"); got != captureReqAll|captureRespHeaders {
|
||||
t.Fatalf("/v1/audio/speech = %b", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
|
||||
// the model config leaves concurrencyLimit unset. Matches the legacy
|
||||
// proxy.Process default.
|
||||
const defaultConcurrencyLimit = 10
|
||||
|
||||
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
|
||||
// model-dispatched requests per model. Each model gets a semaphore sized to
|
||||
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
|
||||
// immediately acquire a slot is rejected with 429. Models without a local
|
||||
// config entry (e.g. peer-routed models) are not limited.
|
||||
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
|
||||
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
|
||||
for id, mc := range cfg.Models {
|
||||
limit := defaultConcurrencyLimit
|
||||
if mc.ConcurrencyLimit > 0 {
|
||||
limit = mc.ConcurrencyLimit
|
||||
}
|
||||
semaphores[id] = semaphore.NewWeighted(int64(limit))
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
// fall through for peer models
|
||||
sem, ok := semaphores[data.ModelID]
|
||||
if !ok {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if !sem.TryAcquire(1) {
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer sem.Release(1)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
func concurrencyTestReq(model string) *http.Request {
|
||||
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
|
||||
return r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: model, ModelID: model}))
|
||||
}
|
||||
|
||||
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"m1": {ConcurrencyLimit: 1},
|
||||
},
|
||||
}
|
||||
|
||||
entered := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var once sync.Once
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
once.Do(func() { close(entered) })
|
||||
<-release
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
h := CreateConcurrencyMiddleware(cfg)(final)
|
||||
|
||||
// First request occupies the only slot.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
|
||||
}()
|
||||
<-entered
|
||||
|
||||
// Second concurrent request is rejected with 429.
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("over-limit status = %d, want 429", w.Code)
|
||||
}
|
||||
|
||||
// Once the slot frees, a new request succeeds.
|
||||
close(release)
|
||||
<-done
|
||||
w = httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("m1"))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("post-release status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{}}
|
||||
|
||||
called := 0
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
h := CreateConcurrencyMiddleware(cfg)(final)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
|
||||
if w.Code != http.StatusOK || called != 1 {
|
||||
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func TestServer_DecompressBody(t *testing.T) {
|
||||
plain := []byte("hello world")
|
||||
|
||||
var gz bytes.Buffer
|
||||
gw := gzip.NewWriter(&gz)
|
||||
gw.Write(plain)
|
||||
gw.Close()
|
||||
|
||||
var fl bytes.Buffer
|
||||
fw, _ := flate.NewWriter(&fl, flate.DefaultCompression)
|
||||
fw.Write(plain)
|
||||
fw.Close()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
body []byte
|
||||
encoding string
|
||||
}{
|
||||
{"plain", plain, ""},
|
||||
{"gzip", gz.Bytes(), "gzip"},
|
||||
{"deflate", fl.Bytes(), "deflate"},
|
||||
{"unknown passthrough", plain, "br"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
got, err := decompressBody(c.body, c.encoding)
|
||||
if err != nil {
|
||||
t.Fatalf("decompressBody: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plain) {
|
||||
t.Errorf("got %q, want %q", got, plain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_FilterAcceptEncoding(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", ""},
|
||||
{"gzip, deflate, br", "gzip, deflate"},
|
||||
{"br, zstd", ""},
|
||||
{"gzip;q=1.0", "gzip;q=1.0"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := filterAcceptEncoding(c.in); got != c.want {
|
||||
t.Errorf("filterAcceptEncoding(%q) = %q, want %q", c.in, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_BodyCopier_Flush(t *testing.T) {
|
||||
bc := newBodyCopier(httptest.NewRecorder())
|
||||
bc.Write([]byte("data"))
|
||||
bc.Flush()
|
||||
if bc.Status() != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", bc.Status())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HeaderMapAndRedact(t *testing.T) {
|
||||
h := http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
"Authorization": {"Bearer secret"},
|
||||
"X-Api-Key": {"key123"},
|
||||
}
|
||||
m := headerMap(h)
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Errorf("Content-Type = %q", m["Content-Type"])
|
||||
}
|
||||
|
||||
redactHeaders(m)
|
||||
if m["Authorization"] != "[REDACTED]" || m["X-Api-Key"] != "[REDACTED]" {
|
||||
t.Errorf("sensitive headers not redacted: %v", m)
|
||||
}
|
||||
if m["Content-Type"] != "application/json" {
|
||||
t.Error("non-sensitive header should not be redacted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_StripVersionPrefix(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/v/v1/chat", nil)
|
||||
stripVersionPrefix(r)
|
||||
if r.URL.Path != "/v1/chat" {
|
||||
t.Errorf("path = %q, want /v1/chat", r.URL.Path)
|
||||
}
|
||||
|
||||
r2 := httptest.NewRequest(http.MethodGet, "/v1/chat", nil)
|
||||
stripVersionPrefix(r2)
|
||||
if r2.URL.Path != "/v1/chat" {
|
||||
t.Errorf("path = %q, want unchanged", r2.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CloseStreams(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.CloseStreams()
|
||||
select {
|
||||
case <-s.shutdownCtx.Done():
|
||||
default:
|
||||
t.Error("CloseStreams did not cancel shutdown context")
|
||||
}
|
||||
s.CloseStreams() // idempotent
|
||||
}
|
||||
|
||||
func TestServer_HandleUIAndFavicon(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
for _, path := range []string{"/ui/", "/favicon.ico"} {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||
// The embedded ui_dist only carries placeholder.txt in test builds, so
|
||||
// these resolve to 404 — the handlers still execute end to end.
|
||||
if w.Code != http.StatusOK && w.Code != http.StatusNotFound {
|
||||
t.Errorf("%s: status = %d", path, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleAPIUnloadAll(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if local.unloadCalls.Load() != 1 {
|
||||
t.Errorf("unloadCalls = %d, want 1", local.unloadCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleAPIUnloadModel(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
|
||||
|
||||
t.Run("known model", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/m1", nil))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown model 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/unload/nope", nil))
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_HandleAPICapture(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.metrics = newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
|
||||
s.metrics.addCapture(ReqRespCapture{ID: 42, ReqPath: "/v1/chat/completions"})
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/42", nil))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if !bytes.Contains(w.Body.Bytes(), []byte("/v1/chat/completions")) {
|
||||
t.Errorf("body = %q", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/999", nil))
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid id", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/captures/abc", nil))
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want 400", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// CreateFilterMiddleware returns middleware that applies per-model request-body
|
||||
// filters to JSON requests before they are forwarded upstream:
|
||||
//
|
||||
// - UseModelName rewrite (issue #69)
|
||||
// - StripParams removal (issue #174)
|
||||
// - SetParams injection (issue #453)
|
||||
// - SetParamsByID per-alias overrides
|
||||
//
|
||||
// Non-JSON requests (GET, multipart forms) pass through untouched. The buffered
|
||||
// body is re-attached with Content-Length / Transfer-Encoding cleanup so the
|
||||
// downstream reverse proxy forwards the correct bytes (see issue #11).
|
||||
func CreateFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), "application/json") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
useModelName, filters, ok := resolveFilters(cfg, data.Model)
|
||||
if !ok {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, "could not read request body")
|
||||
return
|
||||
}
|
||||
|
||||
body, err = applyFilters(body, data.Model, useModelName, filters)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
r.Header.Del("Transfer-Encoding")
|
||||
r.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||
r.ContentLength = int64(len(body))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// CreateFormFilterMiddleware returns middleware that applies the UseModelName
|
||||
// rewrite (issue #69) to multipart/form-data requests before they are forwarded
|
||||
// upstream. JSON-body filters (StripParams, SetParams) do not apply to form
|
||||
// endpoints; only the "model" field is rewritten.
|
||||
//
|
||||
// Non-multipart requests pass through untouched. When a rewrite is needed the
|
||||
// form is reconstructed and re-attached with Content-Type / Content-Length
|
||||
// cleanup so the downstream reverse proxy forwards the correct bytes.
|
||||
func CreateFormFilterMiddleware(cfg config.Config) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
useModelName, _, ok := resolveFilters(cfg, data.Model)
|
||||
if !ok || useModelName == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
body, contentType, err := rewriteMultipartModel(r.MultipartForm, useModelName)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
r.MultipartForm = nil
|
||||
r.Header.Del("Transfer-Encoding")
|
||||
r.Header.Set("Content-Type", contentType)
|
||||
r.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||
r.ContentLength = int64(len(body))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteMultipartModel reconstructs a multipart form, replacing the "model"
|
||||
// field value with useModelName. It returns the encoded body and the matching
|
||||
// Content-Type header (which carries the generated boundary).
|
||||
func rewriteMultipartModel(form *multipart.Form, useModelName string) ([]byte, string, error) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
|
||||
for key, values := range form.Value {
|
||||
for _, value := range values {
|
||||
if key == "model" {
|
||||
value = useModelName
|
||||
}
|
||||
field, err := mw.CreateFormField(key)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error recreating form field %s: %w", key, err)
|
||||
}
|
||||
if _, err := field.Write([]byte(value)); err != nil {
|
||||
return nil, "", fmt.Errorf("error writing form field %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for key, headers := range form.File {
|
||||
for _, fh := range headers {
|
||||
part, err := mw.CreateFormFile(key, fh.Filename)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error recreating form file %s: %w", key, err)
|
||||
}
|
||||
file, err := fh.Open()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error opening uploaded file %s: %w", key, err)
|
||||
}
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
file.Close()
|
||||
return nil, "", fmt.Errorf("error copying file data %s: %w", key, err)
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, "", fmt.Errorf("error finalizing multipart form: %w", err)
|
||||
}
|
||||
return buf.Bytes(), mw.FormDataContentType(), nil
|
||||
}
|
||||
|
||||
// resolveFilters returns the filter settings for a requested model. UseModelName
|
||||
// only applies to local models; peers carry filters but no name rewrite.
|
||||
func resolveFilters(cfg config.Config, requested string) (useModelName string, filters config.Filters, ok bool) {
|
||||
if realName, found := cfg.RealModelName(requested); found {
|
||||
mc := cfg.Models[realName]
|
||||
return mc.UseModelName, mc.Filters.Filters, true
|
||||
}
|
||||
for _, peer := range cfg.Peers {
|
||||
for _, m := range peer.Models {
|
||||
if m == requested {
|
||||
return "", peer.Filters, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", config.Filters{}, false
|
||||
}
|
||||
|
||||
// applyFilters rewrites the JSON body in place. Order matches the legacy
|
||||
// ProxyManager: useModelName, stripParams, setParams, then setParamsByID (which
|
||||
// can override setParams).
|
||||
func applyFilters(body []byte, requested, useModelName string, f config.Filters) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
if useModelName != "" {
|
||||
if body, err = sjson.SetBytes(body, "model", useModelName); err != nil {
|
||||
return nil, fmt.Errorf("error rewriting model name in JSON: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, param := range f.SanitizedStripParams() {
|
||||
if body, err = sjson.DeleteBytes(body, param); err != nil {
|
||||
return nil, fmt.Errorf("error stripping parameter %s from request", param)
|
||||
}
|
||||
}
|
||||
|
||||
setParams, setKeys := f.SanitizedSetParams()
|
||||
for _, key := range setKeys {
|
||||
if body, err = sjson.SetBytes(body, key, setParams[key]); err != nil {
|
||||
return nil, fmt.Errorf("error setting parameter %s in request", key)
|
||||
}
|
||||
}
|
||||
|
||||
byID, byIDKeys := f.SanitizedSetParamsByID(requested)
|
||||
for _, key := range byIDKeys {
|
||||
if body, err = sjson.SetBytes(body, key, byID[key]); err != nil {
|
||||
return nil, fmt.Errorf("error setting parameter %s in request", key)
|
||||
}
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestServer_ApplyFilters(t *testing.T) {
|
||||
t.Run("useModelName rewrite", func(t *testing.T) {
|
||||
out, err := applyFilters([]byte(`{"model":"alias","temp":1}`), "alias", "real-model", config.Filters{})
|
||||
if err != nil {
|
||||
t.Fatalf("applyFilters: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "model").String(); got != "real-model" {
|
||||
t.Errorf("model = %q, want real-model", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("strip and set params", func(t *testing.T) {
|
||||
f := config.Filters{
|
||||
StripParams: "temperature",
|
||||
SetParams: map[string]any{"top_p": 0.9},
|
||||
}
|
||||
out, err := applyFilters([]byte(`{"model":"m","temperature":0.7}`), "m", "", f)
|
||||
if err != nil {
|
||||
t.Fatalf("applyFilters: %v", err)
|
||||
}
|
||||
if gjson.GetBytes(out, "temperature").Exists() {
|
||||
t.Error("temperature should be stripped")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "top_p").Float(); got != 0.9 {
|
||||
t.Errorf("top_p = %v, want 0.9", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("setParamsByID overrides setParams", func(t *testing.T) {
|
||||
f := config.Filters{
|
||||
SetParams: map[string]any{"top_p": 0.5},
|
||||
SetParamsByID: map[string]map[string]any{"alias": {"top_p": 0.1}},
|
||||
}
|
||||
out, err := applyFilters([]byte(`{"model":"alias"}`), "alias", "", f)
|
||||
if err != nil {
|
||||
t.Fatalf("applyFilters: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "top_p").Float(); got != 0.1 {
|
||||
t.Errorf("top_p = %v, want 0.1", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_RewriteMultipartModel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
mw.WriteField("model", "old-name")
|
||||
mw.WriteField("language", "en")
|
||||
fw, _ := mw.CreateFormFile("file", "audio.wav")
|
||||
fw.Write([]byte("RIFFdata"))
|
||||
mw.Close()
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
t.Fatalf("ParseMultipartForm: %v", err)
|
||||
}
|
||||
|
||||
body, contentType, err := rewriteMultipartModel(r.MultipartForm, "new-name")
|
||||
if err != nil {
|
||||
t.Fatalf("rewriteMultipartModel: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := multipart.NewReader(bytes.NewReader(body), boundaryOf(t, contentType)).ReadForm(32 << 20)
|
||||
if err != nil {
|
||||
t.Fatalf("re-parse: %v", err)
|
||||
}
|
||||
if got := parsed.Value["model"][0]; got != "new-name" {
|
||||
t.Errorf("model = %q, want new-name", got)
|
||||
}
|
||||
if got := parsed.Value["language"][0]; got != "en" {
|
||||
t.Errorf("language = %q, want en", got)
|
||||
}
|
||||
fh := parsed.File["file"][0]
|
||||
f, _ := fh.Open()
|
||||
data, _ := io.ReadAll(f)
|
||||
f.Close()
|
||||
if string(data) != "RIFFdata" {
|
||||
t.Errorf("file data = %q, want RIFFdata", data)
|
||||
}
|
||||
}
|
||||
|
||||
func boundaryOf(t *testing.T, contentType string) string {
|
||||
t.Helper()
|
||||
_, params, ok := strings.Cut(contentType, "boundary=")
|
||||
if !ok {
|
||||
t.Fatalf("no boundary in %q", contentType)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func TestServer_FormFilterMiddleware(t *testing.T) {
|
||||
cfg := config.Config{Models: map[string]config.ModelConfig{
|
||||
"whisper": {UseModelName: "whisper-large-v3"},
|
||||
}}
|
||||
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
mw.WriteField("model", "whisper")
|
||||
fw, _ := mw.CreateFormFile("file", "a.wav")
|
||||
fw.Write([]byte("xx"))
|
||||
mw.Close()
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
var gotModel string
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseMultipartForm(32 << 20)
|
||||
gotModel = r.MultipartForm.Value["model"][0]
|
||||
})
|
||||
CreateFormFilterMiddleware(cfg)(final).ServeHTTP(httptest.NewRecorder(), r)
|
||||
|
||||
if gotModel != "whisper-large-v3" {
|
||||
t.Errorf("model rewritten to %q, want whisper-large-v3", gotModel)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// inflightCounter tracks the number of in-flight model-dispatched requests.
|
||||
type inflightCounter struct {
|
||||
total atomic.Int64
|
||||
}
|
||||
|
||||
func (c *inflightCounter) Increment() int64 { return c.total.Add(1) }
|
||||
func (c *inflightCounter) Decrement() int64 { return c.total.Add(-1) }
|
||||
func (c *inflightCounter) Current() int64 { return c.total.Load() }
|
||||
|
||||
// CreateInflightMiddleware returns middleware that increments the counter on
|
||||
// entry and decrements on exit, emitting an InFlightRequestsEvent for each.
|
||||
func CreateInflightMiddleware(c *inflightCounter) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
event.Emit(shared.InFlightRequestsEvent{Total: int(c.Increment())})
|
||||
defer func() {
|
||||
event.Emit(shared.InFlightRequestsEvent{Total: int(c.Decrement())})
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// NewLoggers builds the proxy, upstream, and combined (mux) log monitors,
|
||||
// wiring each one's output per the logToStdout config value. The proxy and
|
||||
// upstream monitors write into muxlog (rather than os.Stdout directly) so
|
||||
// muxlog accumulates a combined history for the /logs endpoints, while each
|
||||
// monitor keeps its own per-source history and event subscribers.
|
||||
//
|
||||
// Behaviour matches the legacy ProxyManager:
|
||||
//
|
||||
// - none: everything discarded
|
||||
// - both: proxy + upstream both routed to muxlog -> stdout
|
||||
// - upstream: only upstream routed to muxlog -> stdout; proxy discarded
|
||||
// - proxy: only proxy routed to muxlog -> stdout; upstream discarded
|
||||
//
|
||||
// An empty or unrecognised value behaves like "proxy".
|
||||
func NewLoggers(logToStdout string) (muxlog, proxylog, upstreamlog *logmon.Monitor) {
|
||||
switch logToStdout {
|
||||
case config.LogToStdoutNone:
|
||||
muxlog = logmon.NewWriter(io.Discard)
|
||||
proxylog = logmon.NewWriter(io.Discard)
|
||||
upstreamlog = logmon.NewWriter(io.Discard)
|
||||
case config.LogToStdoutBoth:
|
||||
muxlog = logmon.NewWriter(os.Stdout)
|
||||
proxylog = logmon.NewWriter(muxlog)
|
||||
upstreamlog = logmon.NewWriter(muxlog)
|
||||
case config.LogToStdoutUpstream:
|
||||
muxlog = logmon.NewWriter(os.Stdout)
|
||||
proxylog = logmon.NewWriter(io.Discard)
|
||||
upstreamlog = logmon.NewWriter(muxlog)
|
||||
default:
|
||||
// config.LogToStdoutProxy, and the fallback for an unset value.
|
||||
muxlog = logmon.NewWriter(os.Stdout)
|
||||
proxylog = logmon.NewWriter(muxlog)
|
||||
upstreamlog = logmon.NewWriter(io.Discard)
|
||||
}
|
||||
return muxlog, proxylog, upstreamlog
|
||||
}
|
||||
|
||||
// handleLogs serves the historical proxy/upstream log. HTML clients are
|
||||
// redirected to the UI.
|
||||
func (s *Server) handleLogs(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.Header.Get("Accept"), "text/html") {
|
||||
http.Redirect(w, r, "/ui/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Write(s.muxlog.GetHistory())
|
||||
}
|
||||
|
||||
// getLogger resolves a log monitor by id. An empty id maps to the combined
|
||||
// muxlog; "proxy" and "upstream" select the respective monitors.
|
||||
func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
|
||||
switch logMonitorID {
|
||||
case "":
|
||||
return s.muxlog, nil
|
||||
case "proxy":
|
||||
return s.proxylog, nil
|
||||
case "upstream":
|
||||
return s.upstreamlog, nil
|
||||
default:
|
||||
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found {
|
||||
if log, ok := s.local.ProcessLogger(modelID); ok {
|
||||
return log, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogStream tails a log monitor: it writes the history then streams live
|
||||
// log data until the client disconnects or the server shuts down.
|
||||
func (s *Server) handleLogStream(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering streamed logs
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
logMonitorID := strings.TrimPrefix(r.PathValue("logMonitorID"), "/")
|
||||
// Strip a query string if it leaked into the path segment.
|
||||
if idx := strings.Index(logMonitorID, "?"); idx != -1 {
|
||||
logMonitorID = logMonitorID[:idx]
|
||||
}
|
||||
|
||||
logger, err := s.getLogger(logMonitorID)
|
||||
if err != nil {
|
||||
router.SendResponse(w, r, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
_, skipHistory := r.URL.Query()["no-history"]
|
||||
if !skipHistory {
|
||||
if history := logger.GetHistory(); len(history) != 0 {
|
||||
w.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
sendChan := make(chan []byte, 10)
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
cancelSub := logger.OnLogData(func(data []byte) {
|
||||
select {
|
||||
case sendChan <- data:
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
}
|
||||
})
|
||||
defer cancelSub()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-s.shutdownCtx.Done():
|
||||
return
|
||||
case data := <-sendChan:
|
||||
w.Write(data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// requestLogPathSkips lists path prefixes excluded from the access log because
|
||||
// they are polled frequently and would drown out useful entries.
|
||||
var requestLogPathSkips = []string{"/wol-health", "/api/performance", "/metrics"}
|
||||
|
||||
// statusRecorder wraps an http.ResponseWriter to capture the response status
|
||||
// code and the number of body bytes written, so the access log can report
|
||||
// them. Flush is forwarded so streaming handlers (SSE) still work.
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) WriteHeader(code int) {
|
||||
sr.status = code
|
||||
sr.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) Write(b []byte) (int, error) {
|
||||
n, err := sr.ResponseWriter.Write(b)
|
||||
sr.size += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) Flush() {
|
||||
if f, ok := sr.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// clientIP resolves the originating client address, preferring proxy headers
|
||||
// over the raw connection address.
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
if first, _, found := strings.Cut(xff, ","); found {
|
||||
return strings.TrimSpace(first)
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
if xr := r.Header.Get("X-Real-IP"); xr != "" {
|
||||
return strings.TrimSpace(xr)
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// CreateRequestLogMiddleware returns middleware that records one access-log
|
||||
// line per request to proxylog, in the legacy format:
|
||||
//
|
||||
// clientIP "METHOD PATH PROTO" status bodySize "UA" duration
|
||||
//
|
||||
// Frequently-polled health/metrics paths are skipped. The path is captured
|
||||
// before next runs because /upstream rewrites the request URL in place.
|
||||
func CreateRequestLogMiddleware(proxylog *logmon.Monitor) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, prefix := range requestLogPathSkips {
|
||||
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
ip, method, path, proto, ua := clientIP(r), r.Method, r.URL.Path, r.Proto, r.UserAgent()
|
||||
|
||||
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(rec, r)
|
||||
|
||||
proxylog.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
||||
ip, method, path, proto, rec.status, rec.size, ua, time.Since(start))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
)
|
||||
|
||||
func TestServer_NewLoggers(t *testing.T) {
|
||||
t.Run("proxy mode routes proxy into muxlog, discards upstream", func(t *testing.T) {
|
||||
mux, proxy, upstream := NewLoggers(config.LogToStdoutProxy)
|
||||
proxy.Info("PROXYLINE")
|
||||
upstream.Info("UPSTREAMLINE")
|
||||
h := string(mux.GetHistory())
|
||||
if !strings.Contains(h, "PROXYLINE") {
|
||||
t.Errorf("muxlog missing proxy line: %q", h)
|
||||
}
|
||||
if strings.Contains(h, "UPSTREAMLINE") {
|
||||
t.Errorf("muxlog should not contain upstream line: %q", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("both mode routes proxy and upstream into muxlog", func(t *testing.T) {
|
||||
mux, proxy, upstream := NewLoggers(config.LogToStdoutBoth)
|
||||
proxy.Info("PROXYLINE")
|
||||
upstream.Info("UPSTREAMLINE")
|
||||
h := string(mux.GetHistory())
|
||||
if !strings.Contains(h, "PROXYLINE") || !strings.Contains(h, "UPSTREAMLINE") {
|
||||
t.Errorf("muxlog history = %q", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("none mode discards everything from muxlog", func(t *testing.T) {
|
||||
mux, proxy, upstream := NewLoggers(config.LogToStdoutNone)
|
||||
proxy.Info("PROXYLINE")
|
||||
upstream.Info("UPSTREAMLINE")
|
||||
if len(mux.GetHistory()) != 0 {
|
||||
t.Errorf("muxlog should be empty, got %q", mux.GetHistory())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_HandleLogs_Plain(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
s.muxlog.Write([]byte("a log line"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d", w.Code)
|
||||
}
|
||||
if ct := w.Header().Get("Content-Type"); ct != "text/plain" {
|
||||
t.Errorf("Content-Type = %q, want text/plain", ct)
|
||||
}
|
||||
if w.Body.String() != "a log line" {
|
||||
t.Errorf("body = %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleLogs_HTMLRedirect(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
req.Header.Set("Accept", "text/html")
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Fatalf("status = %d, want 302", w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Location"); got != "/ui/" {
|
||||
t.Errorf("Location = %q, want /ui/", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ClientIP(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
setup func(*http.Request)
|
||||
want string
|
||||
}{
|
||||
{"remote addr", func(r *http.Request) { r.RemoteAddr = "10.0.0.5:1234" }, "10.0.0.5"},
|
||||
{"x-forwarded-for", func(r *http.Request) {
|
||||
r.Header.Set("X-Forwarded-For", "1.2.3.4, 5.6.7.8")
|
||||
}, "1.2.3.4"},
|
||||
{"x-real-ip", func(r *http.Request) { r.Header.Set("X-Real-IP", "9.9.9.9") }, "9.9.9.9"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.RemoteAddr = ""
|
||||
c.setup(r)
|
||||
if got := clientIP(r); got != c.want {
|
||||
t.Errorf("clientIP() = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RequestLogMiddleware(t *testing.T) {
|
||||
proxylog := logmon.NewWriter(io.Discard)
|
||||
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte("hello"))
|
||||
})
|
||||
mw := CreateRequestLogMiddleware(proxylog)
|
||||
|
||||
t.Run("logs request", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
r.RemoteAddr = "192.168.1.1:5000"
|
||||
mw(final).ServeHTTP(httptest.NewRecorder(), r)
|
||||
|
||||
line := string(proxylog.GetHistory())
|
||||
for _, want := range []string{"192.168.1.1", "POST /v1/chat/completions", "201", "5"} {
|
||||
if !strings.Contains(line, want) {
|
||||
t.Errorf("log line %q missing %q", line, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for _, path := range []string{"/wol-health", "/api/performance", "/metrics"} {
|
||||
t.Run("skips "+path, func(t *testing.T) {
|
||||
skipLog := logmon.NewWriter(io.Discard)
|
||||
skipMW := CreateRequestLogMiddleware(skipLog)
|
||||
skipMW(final).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, path, nil))
|
||||
if len(skipLog.GetHistory()) != 0 {
|
||||
t.Errorf("%s should not be logged; got %q", path, skipLog.GetHistory())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/cache"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/ring"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TokenMetrics holds token usage and performance metrics.
|
||||
type TokenMetrics struct {
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
}
|
||||
|
||||
// ActivityLogEntry represents parsed token statistics from llama-server logs.
|
||||
type ActivityLogEntry struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
ReqPath string `json:"req_path"`
|
||||
RespContentType string `json:"resp_content_type"`
|
||||
RespStatusCode int `json:"resp_status_code"`
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
// ActivityLogEvent carries a single activity log entry to event subscribers.
|
||||
type ActivityLogEvent struct {
|
||||
Metrics ActivityLogEntry
|
||||
}
|
||||
|
||||
func (e ActivityLogEvent) Type() uint32 {
|
||||
return shared.ActivityLogEventID
|
||||
}
|
||||
|
||||
// metricsMonitor parses upstream responses for token statistics, keeps a
|
||||
// bounded in-memory ring of recent activity, and (when captures are enabled)
|
||||
// stores zstd+CBOR-compressed request/response captures in a sized cache.
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics ring.Buffer[ActivityLogEntry]
|
||||
nextID int
|
||||
logger *logmon.Monitor
|
||||
|
||||
enableCaptures bool
|
||||
captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
|
||||
}
|
||||
|
||||
// newMetricsMonitor creates a metricsMonitor retaining up to maxMetrics entries.
|
||||
// captureBufferMB is the capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *logmon.Monitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
if maxMetrics <= 0 {
|
||||
maxMetrics = 1000
|
||||
}
|
||||
mm := &metricsMonitor{
|
||||
logger: logger,
|
||||
metrics: ring.NewBuffer[ActivityLogEntry](maxMetrics),
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
}
|
||||
if captureBufferMB > 0 {
|
||||
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
|
||||
}
|
||||
return mm
|
||||
}
|
||||
|
||||
// queueMetrics adds a metric to the ring and returns its assigned ID.
|
||||
func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics.Push(metric)
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// emitMetric publishes an ActivityLogEvent for the given metric.
|
||||
func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
|
||||
event.Emit(ActivityLogEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics.
|
||||
func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := mp.metrics.Slice()
|
||||
if result == nil {
|
||||
return []ActivityLogEntry{}
|
||||
}
|
||||
if mp.captureCache != nil {
|
||||
for i := range result {
|
||||
result[i].HasCapture = mp.captureCache.Has(result[i].ID)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getMetricsJSON returns the current metrics as a JSON array.
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
return json.Marshal(mp.getMetrics())
|
||||
}
|
||||
|
||||
// record parses a completed response body and stores/emits an activity entry.
|
||||
// When captures are enabled, a zstd+CBOR capture is stored for successful
|
||||
// requests, with cf controlling which request/response parts are retained.
|
||||
// reqBody and reqHeaders are the request data buffered before dispatch.
|
||||
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
||||
tm := ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
ReqPath: r.URL.Path,
|
||||
RespContentType: recorder.Header().Get("Content-Type"),
|
||||
RespStatusCode: recorder.Status(),
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
queueAndEmit := func() {
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
mp.emitMetric(tm)
|
||||
}
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
||||
queueAndEmit()
|
||||
return
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
queueAndEmit()
|
||||
return
|
||||
}
|
||||
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
decoded, err := decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||
queueAndEmit()
|
||||
return
|
||||
}
|
||||
body = decoded
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsed.Tokens
|
||||
tm.DurationMs = parsed.DurationMs
|
||||
}
|
||||
} else if gjson.ValidBytes(body) {
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// /infill responses are arrays; timings live in the last element (#463).
|
||||
if strings.HasPrefix(r.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||
} else {
|
||||
tm.Tokens = parsedMetrics.Tokens
|
||||
tm.DurationMs = parsedMetrics.DurationMs
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", r.URL.Path)
|
||||
}
|
||||
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
if mp.enableCaptures {
|
||||
capture := ReqRespCapture{
|
||||
ID: tm.ID,
|
||||
ReqPath: r.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
}
|
||||
if cf&captureReqBody != 0 {
|
||||
capture.ReqBody = reqBody
|
||||
}
|
||||
if cf&captureRespHeaders != 0 {
|
||||
capture.RespHeaders = headerMap(recorder.Header())
|
||||
redactHeaders(capture.RespHeaders)
|
||||
delete(capture.RespHeaders, "Content-Encoding")
|
||||
}
|
||||
if cf&captureRespBody != 0 {
|
||||
capture.RespBody = body
|
||||
}
|
||||
if mp.addCapture(capture) {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
mp.emitMetric(tm)
|
||||
}
|
||||
|
||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||
var usagePaths = []string{"usage", "response.usage", "message.usage"}
|
||||
|
||||
// extractUsageTokens reads input/output/cached token counts from a usage
|
||||
// gjson.Result, handling the field-name differences across endpoints.
|
||||
func extractUsageTokens(usage gjson.Result) (input, output, cached int64, ok bool) {
|
||||
cached = -1
|
||||
if !usage.Exists() {
|
||||
return
|
||||
}
|
||||
|
||||
if v := usage.Get("prompt_tokens"); v.Exists() {
|
||||
input = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens"); v.Exists() {
|
||||
input = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("completion_tokens"); v.Exists() {
|
||||
output = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("output_tokens"); v.Exists() {
|
||||
output = v.Int()
|
||||
ok = true
|
||||
}
|
||||
|
||||
if v := usage.Get("cache_read_input_tokens"); v.Exists() {
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("input_tokens_details.cached_tokens"); v.Exists() {
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
} else if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
|
||||
cached = v.Int()
|
||||
ok = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
|
||||
var (
|
||||
inputTokens, outputTokens int64
|
||||
cachedTokens int64 = -1
|
||||
hasAny bool
|
||||
timings gjson.Result
|
||||
)
|
||||
|
||||
prefix := []byte("data:")
|
||||
for offset := 0; offset < len(body); {
|
||||
nl := bytes.IndexByte(body[offset:], '\n')
|
||||
var line []byte
|
||||
if nl == -1 {
|
||||
line = body[offset:]
|
||||
offset = len(body)
|
||||
} else {
|
||||
line = body[offset : offset+nl]
|
||||
offset += nl + 1
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if !gjson.ValidBytes(data) {
|
||||
continue
|
||||
}
|
||||
parsed := gjson.ParseBytes(data)
|
||||
|
||||
for _, path := range usagePaths {
|
||||
u := parsed.Get(path)
|
||||
if !u.Exists() {
|
||||
continue
|
||||
}
|
||||
i, o, c, ok := extractUsageTokens(u)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hasAny = true
|
||||
if i > 0 {
|
||||
inputTokens = i
|
||||
}
|
||||
if o > 0 {
|
||||
outputTokens = o
|
||||
}
|
||||
if c >= 0 {
|
||||
cachedTokens = c
|
||||
}
|
||||
}
|
||||
if t := parsed.Get("timings"); t.Exists() {
|
||||
timings = t
|
||||
hasAny = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAny {
|
||||
return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
return buildMetrics(modelID, start, inputTokens, outputTokens, cachedTokens, timings), nil
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
|
||||
input, output, cached, _ := extractUsageTokens(usage)
|
||||
return buildMetrics(modelID, start, input, output, cached, timings), nil
|
||||
}
|
||||
|
||||
// buildMetrics composes an ActivityLogEntry from accumulated token counts and
|
||||
// optional llama-server timings (which override input/output and provide rates).
|
||||
func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, cachedTokens int64, timings gjson.Result) ActivityLogEntry {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
durationMs := wallDurationMs
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
|
||||
if timings.Exists() {
|
||||
inputTokens = timings.Get("prompt_n").Int()
|
||||
outputTokens = timings.Get("predicted_n").Int()
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
if timingsDurationMs > durationMs {
|
||||
durationMs = timingsDurationMs
|
||||
}
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = cachedValue.Int()
|
||||
}
|
||||
}
|
||||
|
||||
return ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: int(cachedTokens),
|
||||
InputTokens: int(inputTokens),
|
||||
OutputTokens: int(outputTokens),
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
},
|
||||
DurationMs: durationMs,
|
||||
}
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on the Content-Encoding header.
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters Accept-Encoding to only gzip/deflate so response
|
||||
// bodies remain decompressible for metrics parsing.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
|
||||
// responseBodyCopier tees the upstream response to the client while buffering
|
||||
// it for metrics parsing. Status defaults to 200 until WriteHeader is called.
|
||||
type responseBodyCopier struct {
|
||||
http.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
tee io.Writer
|
||||
status int
|
||||
wroteHeader bool
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newBodyCopier(w http.ResponseWriter) *responseBodyCopier {
|
||||
buf := &bytes.Buffer{}
|
||||
return &responseBodyCopier{
|
||||
ResponseWriter: w,
|
||||
body: buf,
|
||||
tee: io.MultiWriter(w, buf),
|
||||
status: http.StatusOK,
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||
if w.wroteHeader {
|
||||
return
|
||||
}
|
||||
w.wroteHeader = true
|
||||
w.status = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Flush forwards to the underlying writer so streaming responses still flush.
|
||||
func (w *responseBodyCopier) Flush() {
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Status() int { return w.status }
|
||||
func (w *responseBodyCopier) StartTime() time.Time { return w.start }
|
||||
@@ -0,0 +1,62 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// CreateMetricsMiddleware returns middleware that records token metrics for
|
||||
// model-dispatched POST requests. It resolves the model, tees the response into
|
||||
// a buffer, and parses token usage once the upstream handler returns.
|
||||
func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if mm == nil || r.Method != http.MethodPost {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve the model now so downstream dispatch hits the context
|
||||
// fast path; FetchContext restores the request body.
|
||||
data, err := router.FetchContext(r, cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
// Buffer the request body/headers for capture before dispatch
|
||||
// consumes them.
|
||||
cf := captureFieldsFor(r.URL.Path)
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mm.enableCaptures {
|
||||
if cf&captureReqBody != 0 && r.Body != nil {
|
||||
if buffered, err := io.ReadAll(r.Body); err == nil {
|
||||
reqBody = buffered
|
||||
r.Body.Close()
|
||||
r.Body = io.NopCloser(bytes.NewReader(reqBody))
|
||||
}
|
||||
}
|
||||
if cf&captureReqHeaders != 0 {
|
||||
reqHeaders = headerMap(r.Header)
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
// Restrict Accept-Encoding to encodings we can decompress so the
|
||||
// buffered response body stays parseable.
|
||||
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
|
||||
r.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(w)
|
||||
next.ServeHTTP(recorder, r)
|
||||
mm.record(data.ModelID, r, recorder, cf, reqBody, reqHeaders)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestServer_ParseMetrics_ChatCompletions(t *testing.T) {
|
||||
body := `{"usage":{"prompt_tokens":12,"completion_tokens":7,"prompt_tokens_details":{"cached_tokens":4}}}`
|
||||
parsed := gjson.Parse(body)
|
||||
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings"))
|
||||
if err != nil {
|
||||
t.Fatalf("parseMetrics: %v", err)
|
||||
}
|
||||
if entry.Tokens.InputTokens != 12 || entry.Tokens.OutputTokens != 7 || entry.Tokens.CachedTokens != 4 {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ParseMetrics_Timings(t *testing.T) {
|
||||
body := `{"timings":{"prompt_n":20,"predicted_n":50,"prompt_per_second":100.0,"predicted_per_second":40.0,"prompt_ms":200,"predicted_ms":1250,"cache_n":8}}`
|
||||
parsed := gjson.Parse(body)
|
||||
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), parsed.Get("timings"))
|
||||
if err != nil {
|
||||
t.Fatalf("parseMetrics: %v", err)
|
||||
}
|
||||
if entry.Tokens.InputTokens != 20 || entry.Tokens.OutputTokens != 50 || entry.Tokens.CachedTokens != 8 {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
if entry.Tokens.TokensPerSecond != 40.0 || entry.Tokens.PromptPerSecond != 100.0 {
|
||||
t.Fatalf("rates = %+v", entry.Tokens)
|
||||
}
|
||||
if entry.DurationMs != 1450 {
|
||||
t.Fatalf("DurationMs = %d, want 1450", entry.DurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessStreamingResponse(t *testing.T) {
|
||||
body := []byte("data: {\"choices\":[{}]}\n\n" +
|
||||
"data: {\"usage\":{\"prompt_tokens\":15,\"completion_tokens\":33}}\n\n" +
|
||||
"data: [DONE]\n\n")
|
||||
entry, err := processStreamingResponse("m", time.Now(), body)
|
||||
if err != nil {
|
||||
t.Fatalf("processStreamingResponse: %v", err)
|
||||
}
|
||||
if entry.Tokens.InputTokens != 15 || entry.Tokens.OutputTokens != 33 {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
|
||||
if _, err := processStreamingResponse("m", time.Now(), []byte("data: [DONE]\n\n")); err == nil {
|
||||
t.Fatal("expected error for stream with no usage data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||
// /infill responses are arrays; timings live in the last element.
|
||||
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
||||
parsed := gjson.Parse(body)
|
||||
timings := parsed.Get("timings")
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
entry, err := parseMetrics("m", time.Now(), parsed.Get("usage"), timings)
|
||||
if err != nil {
|
||||
t.Fatalf("parseMetrics: %v", err)
|
||||
}
|
||||
if entry.Tokens.InputTokens != 5 || entry.Tokens.OutputTokens != 9 {
|
||||
t.Fatalf("tokens = %+v", entry.Tokens)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,290 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/chain"
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/perf"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
)
|
||||
|
||||
// Server owns the HTTP mux, cross-cutting middleware, and the local/peer model
|
||||
// dispatch. It supersedes router.Server: it builds the local and peer routers
|
||||
// directly and dispatches between them itself.
|
||||
type Server struct {
|
||||
cfg config.Config
|
||||
|
||||
muxlog *logmon.Monitor
|
||||
proxylog *logmon.Monitor
|
||||
upstreamlog *logmon.Monitor
|
||||
|
||||
perf *perf.Monitor
|
||||
inflight *inflightCounter
|
||||
metrics *metricsMonitor
|
||||
build BuildInfo
|
||||
|
||||
local router.LocalRouter
|
||||
peer router.Router
|
||||
|
||||
mux *http.ServeMux
|
||||
handler http.Handler
|
||||
|
||||
shutdownCtx context.Context
|
||||
shutdownFn context.CancelFunc
|
||||
shuttingDown atomic.Bool
|
||||
}
|
||||
|
||||
// modelPostJSONRoutes are endpoints with a model id in the JSON request body.
|
||||
var modelPostJSONRoutes = []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/responses",
|
||||
"/v1/completions",
|
||||
"/v1/messages",
|
||||
"/v1/messages/count_tokens",
|
||||
"/v1/embeddings",
|
||||
"/reranking",
|
||||
"/rerank",
|
||||
"/v1/rerank",
|
||||
"/v1/reranking",
|
||||
"/infill",
|
||||
"/completion",
|
||||
"/v1/audio/speech",
|
||||
"/v1/audio/voices",
|
||||
"/v1/images/generations",
|
||||
"/sdapi/v1/txt2img",
|
||||
"/sdapi/v1/img2img",
|
||||
|
||||
// versionless routes, the /v/ is stripped before the request is forwarded upstream
|
||||
// see issue #728
|
||||
"/v/chat/completions",
|
||||
"/v/responses",
|
||||
"/v/completions",
|
||||
"/v/messages",
|
||||
"/v/messages/count_tokens",
|
||||
"/v/embeddings",
|
||||
"/v/rerank",
|
||||
"/v/reranking",
|
||||
}
|
||||
|
||||
// modelPostFormRoutes are multipart/form-data endpoints with a model id in the form data
|
||||
var modelPostFormRoutes = []string{
|
||||
"/v1/audio/transcriptions",
|
||||
"/v1/images/edits",
|
||||
}
|
||||
|
||||
// modelGetRoutes are model-dispatched GET endpoints (the model arrives as a
|
||||
// query parameter).
|
||||
var modelGetRoutes = []string{
|
||||
"/v1/audio/voices",
|
||||
"/sdapi/v1/loras",
|
||||
}
|
||||
|
||||
// BuildInfo carries version metadata surfaced by GET /api/version.
|
||||
type BuildInfo struct {
|
||||
Version string
|
||||
Commit string
|
||||
Date string
|
||||
}
|
||||
|
||||
func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, upstreamlog *logmon.Monitor, perfMon *perf.Monitor, build BuildInfo) (*Server, error) {
|
||||
var local router.LocalRouter
|
||||
var err error
|
||||
|
||||
if cfg.Matrix != nil {
|
||||
local, err = router.NewMatrix(cfg, proxylog, upstreamlog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating matrix router: %w", err)
|
||||
}
|
||||
} else {
|
||||
local, err = router.NewGroup(cfg, proxylog, upstreamlog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating group router: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := router.NewPeer(cfg, proxylog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating peer router: %w", err)
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
muxlog: muxlog,
|
||||
proxylog: proxylog,
|
||||
upstreamlog: upstreamlog,
|
||||
perf: perfMon,
|
||||
inflight: &inflightCounter{},
|
||||
metrics: newMetricsMonitor(proxylog, cfg.MetricsMaxInMemory, cfg.CaptureBuffer),
|
||||
build: build,
|
||||
local: local,
|
||||
peer: peer,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownFn: shutdownFn,
|
||||
}
|
||||
s.routes()
|
||||
s.startPreload()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// localPeerHandler dispatches a model-routed request to the local or peer
|
||||
// router. The model is resolved once via router.FetchContext.
|
||||
func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
stripVersionPrefix(r)
|
||||
|
||||
data, err := router.FetchContext(r, s.cfg)
|
||||
if err != nil {
|
||||
router.SendError(w, r, router.ErrNoModelInContext)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case s.local.Handles(data.ModelID):
|
||||
s.proxylog.Debugf("dispatch: using local process for model: %s", data.ModelID)
|
||||
s.local.ServeHTTP(w, r)
|
||||
case s.peer.Handles(data.ModelID):
|
||||
s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID)
|
||||
s.peer.ServeHTTP(w, r)
|
||||
default:
|
||||
router.SendError(w, r, router.ErrNoRouterFound)
|
||||
}
|
||||
}
|
||||
|
||||
// stripVersionPrefix rewrites versionless /v/... requests to their /... form
|
||||
// before forwarding upstream (issue #728).
|
||||
func stripVersionPrefix(r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/v/") {
|
||||
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/v")
|
||||
}
|
||||
}
|
||||
|
||||
// routes builds the mux, registers every route, and wraps the mux with the
|
||||
// global CORS middleware.
|
||||
func (s *Server) routes() {
|
||||
authMW := CreateAuthMiddleware(s.cfg)
|
||||
filterMW := CreateFilterMiddleware(s.cfg)
|
||||
formFilterMW := CreateFormFilterMiddleware(s.cfg)
|
||||
|
||||
// Model-dispatched routes get auth + per-model concurrency limiting + body
|
||||
// filters + in-flight tracking + token metrics. concurrencyMW rejects with
|
||||
// 429 before the body filters do any rewrite work. filterMW rewrites JSON
|
||||
// bodies and formFilterMW rewrites multipart bodies; each is a no-op for the
|
||||
// other's Content-Type. Both run before the metrics middleware so it buffers
|
||||
// the rewritten body.
|
||||
modelChain := chain.New(
|
||||
authMW,
|
||||
CreateConcurrencyMiddleware(s.cfg),
|
||||
filterMW,
|
||||
formFilterMW,
|
||||
CreateInflightMiddleware(s.inflight),
|
||||
CreateMetricsMiddleware(s.metrics, s.cfg),
|
||||
)
|
||||
// Custom endpoints only need auth.
|
||||
apiChain := chain.New(authMW)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
dispatch := http.HandlerFunc(s.localPeerHandler)
|
||||
|
||||
for _, path := range modelPostJSONRoutes {
|
||||
mux.Handle("POST "+path, modelChain.Then(dispatch))
|
||||
}
|
||||
for _, path := range modelPostFormRoutes {
|
||||
mux.Handle("POST "+path, modelChain.Then(dispatch))
|
||||
}
|
||||
for _, path := range modelGetRoutes {
|
||||
mux.Handle("GET "+path, modelChain.Then(dispatch))
|
||||
}
|
||||
|
||||
// llama-swap API + custom endpoints.
|
||||
mux.Handle("GET /v1/models", apiChain.ThenFunc(s.handleListModels))
|
||||
mux.Handle("GET /logs", apiChain.ThenFunc(s.handleLogs))
|
||||
mux.Handle("GET /logs/stream", apiChain.ThenFunc(s.handleLogStream))
|
||||
mux.Handle("GET /logs/stream/{logMonitorID...}", apiChain.ThenFunc(s.handleLogStream))
|
||||
|
||||
mux.HandleFunc("GET /health", handleHealth)
|
||||
mux.HandleFunc("GET /wol-health", handleHealth)
|
||||
mux.HandleFunc("GET /{$}", handleRootRedirect)
|
||||
|
||||
// Embedded UI.
|
||||
mux.HandleFunc("GET /ui/", s.handleUI)
|
||||
mux.HandleFunc("GET /favicon.ico", s.handleFavicon)
|
||||
|
||||
// Prometheus metrics (no auth, matches the legacy endpoint).
|
||||
mux.HandleFunc("GET /metrics", s.handleMetrics)
|
||||
|
||||
// Operations endpoints.
|
||||
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
|
||||
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
|
||||
|
||||
// Upstream passthrough.
|
||||
mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
|
||||
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream))
|
||||
|
||||
// API group (API-key protected) consumed by the UI.
|
||||
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
|
||||
mux.Handle("POST /api/models/unload/{model...}", apiChain.ThenFunc(s.handleAPIUnloadModel))
|
||||
mux.Handle("GET /api/events", apiChain.ThenFunc(s.handleAPIEvents))
|
||||
mux.Handle("GET /api/metrics", apiChain.ThenFunc(s.handleAPIMetrics))
|
||||
mux.Handle("GET /api/performance", apiChain.ThenFunc(s.handleAPIPerformance))
|
||||
mux.Handle("GET /api/version", apiChain.ThenFunc(s.handleAPIVersion))
|
||||
mux.Handle("GET /api/captures/{id}", apiChain.ThenFunc(s.handleAPICapture))
|
||||
|
||||
s.mux = mux
|
||||
s.handler = chain.New(CreateRequestLogMiddleware(s.proxylog), CreateCORSMiddleware()).Then(mux)
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// CloseStreams cancels long-lived response streams (Server-Sent Events) so a
|
||||
// graceful httpServer.Shutdown can drain without blocking on them. It does not
|
||||
// tear down routers; call Shutdown for that. Safe to call repeatedly.
|
||||
func (s *Server) CloseStreams() {
|
||||
s.shutdownFn()
|
||||
}
|
||||
|
||||
// Shutdown stops the local and peer routers in parallel. It is idempotent;
|
||||
// repeated calls return nil without re-running shutdown.
|
||||
//
|
||||
// Callers must drain inflight HTTP requests (httpServer.Shutdown) before
|
||||
// calling this, otherwise inflight requests 502 when their processes are torn
|
||||
// down. Call CloseStreams before httpServer.Shutdown so SSE streams do not
|
||||
// block the drain.
|
||||
func (s *Server) Shutdown(timeout time.Duration) error {
|
||||
if !s.shuttingDown.CompareAndSwap(false, true) {
|
||||
return nil
|
||||
}
|
||||
s.shutdownFn()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
var errs []error
|
||||
|
||||
for _, rt := range []router.Router{s.local, s.peer} {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(rt router.Router) {
|
||||
defer wg.Done()
|
||||
if err := rt.Shutdown(timeout); err != nil {
|
||||
mu.Lock()
|
||||
errs = append(errs, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
}(rt)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
@@ -0,0 +1,331 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/router"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
// stubRouter is a minimal router.LocalRouter for Server dispatch tests.
|
||||
type stubRouter struct {
|
||||
models map[string]bool
|
||||
response string
|
||||
shutdownCalls atomic.Int32
|
||||
running map[string]process.ProcessState
|
||||
unloadCalls atomic.Int32
|
||||
loggers map[string]*logmon.Monitor
|
||||
}
|
||||
|
||||
func newStubRouter(models []string, response string) *stubRouter {
|
||||
m := make(map[string]bool, len(models))
|
||||
for _, id := range models {
|
||||
m[id] = true
|
||||
}
|
||||
return &stubRouter{models: m, response: response}
|
||||
}
|
||||
|
||||
func (s *stubRouter) Handles(model string) bool { return s.models[model] }
|
||||
func (s *stubRouter) Shutdown(_ time.Duration) error { s.shutdownCalls.Add(1); return nil }
|
||||
func (s *stubRouter) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(s.response))
|
||||
}
|
||||
|
||||
func (s *stubRouter) RunningModels() map[string]process.ProcessState { return s.running }
|
||||
func (s *stubRouter) Unload(_ time.Duration, _ ...string) { s.unloadCalls.Add(1) }
|
||||
func (s *stubRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
|
||||
if s.loggers != nil {
|
||||
if lg, ok := s.loggers[modelID]; ok {
|
||||
return lg, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// newTestServer wires a Server with stub routers and a built mux.
|
||||
func newTestServer(local router.LocalRouter, peer router.Router) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
proxylog := logmon.NewWriter(io.Discard)
|
||||
s := &Server{
|
||||
cfg: config.Config{},
|
||||
muxlog: logmon.NewWriter(io.Discard),
|
||||
proxylog: proxylog,
|
||||
upstreamlog: logmon.NewWriter(io.Discard),
|
||||
inflight: &inflightCounter{},
|
||||
metrics: newMetricsMonitor(proxylog, 0, 0),
|
||||
local: local,
|
||||
peer: peer,
|
||||
shutdownCtx: ctx,
|
||||
shutdownFn: cancel,
|
||||
}
|
||||
s.routes()
|
||||
return s
|
||||
}
|
||||
|
||||
func chatRequest(model string) *http.Request {
|
||||
body := strings.NewReader(`{"model":"` + model + `"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req
|
||||
}
|
||||
|
||||
func TestServer_New_GroupConfig(t *testing.T) {
|
||||
discard := logmon.NewWriter(io.Discard)
|
||||
s, err := New(config.Config{HealthCheckTimeout: 15}, discard, discard, discard, nil, BuildInfo{})
|
||||
if err != nil {
|
||||
t.Fatalf("New (group): %v", err)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_New_MatrixConfig(t *testing.T) {
|
||||
discard := logmon.NewWriter(io.Discard)
|
||||
cfg := config.Config{HealthCheckTimeout: 15, Matrix: &config.MatrixConfig{}}
|
||||
s, err := New(cfg, discard, discard, discard, nil, BuildInfo{})
|
||||
if err != nil {
|
||||
t.Fatalf("New (matrix): %v", err)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RouteToLocalModel(t *testing.T) {
|
||||
s := newTestServer(
|
||||
newStubRouter([]string{"local-model"}, "local response"),
|
||||
newStubRouter(nil, ""),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, chatRequest("local-model"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if w.Body.String() != "local response" {
|
||||
t.Errorf("body=%q want %q", w.Body.String(), "local response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RouteToPeerModel(t *testing.T) {
|
||||
s := newTestServer(
|
||||
newStubRouter(nil, ""),
|
||||
newStubRouter([]string{"peer-model"}, "peer response"),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, chatRequest("peer-model"))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if w.Body.String() != "peer response" {
|
||||
t.Errorf("body=%q want %q", w.Body.String(), "peer response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_UnknownModelReturns404(t *testing.T) {
|
||||
s := newTestServer(
|
||||
newStubRouter([]string{"local-model"}, ""),
|
||||
newStubRouter(nil, ""),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, chatRequest("unknown-model"))
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d want 404 body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_UnknownPathReturns404(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/does-not-exist", nil))
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Health(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
for _, path := range []string{"/health", "/wol-health"} {
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
|
||||
if w.Code != http.StatusOK || w.Body.String() != "OK" {
|
||||
t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CORSPreflight(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Fatalf("status=%d want 204", w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("Access-Control-Allow-Origin=%q want *", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Unload(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/unload", nil))
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "OK" {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := local.unloadCalls.Load(); got != 1 {
|
||||
t.Errorf("unloadCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Running(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "")
|
||||
local.running = map[string]process.ProcessState{"m1": process.StateReady}
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{
|
||||
"m1": {
|
||||
Cmd: "llama-server",
|
||||
Proxy: "http://localhost:9999",
|
||||
UnloadAfter: 300,
|
||||
Name: "Model One",
|
||||
Description: "the first model",
|
||||
},
|
||||
}}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/running", nil))
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Running []runningModel `json:"running"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode: %v body=%q", err, w.Body.String())
|
||||
}
|
||||
if len(resp.Running) != 1 {
|
||||
t.Fatalf("running=%v want 1 entry", resp.Running)
|
||||
}
|
||||
want := runningModel{
|
||||
Model: "m1",
|
||||
State: "ready",
|
||||
Cmd: "llama-server",
|
||||
Proxy: "http://localhost:9999",
|
||||
TTL: 300,
|
||||
Name: "Model One",
|
||||
Description: "the first model",
|
||||
}
|
||||
if resp.Running[0] != want {
|
||||
t.Errorf("got %+v want %+v", resp.Running[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Preload(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "ok")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Hooks: config.HooksConfig{
|
||||
OnStartup: config.HookOnStartup{Preload: []string{"m1"}},
|
||||
}}
|
||||
|
||||
got := make(chan shared.ModelPreloadedEvent, 1)
|
||||
cancel := event.On(func(e shared.ModelPreloadedEvent) { got <- e })
|
||||
defer cancel()
|
||||
|
||||
s.startPreload()
|
||||
|
||||
select {
|
||||
case e := <-got:
|
||||
if e.ModelName != "m1" || !e.Success {
|
||||
t.Errorf("event=%+v want {ModelName:m1 Success:true}", e)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("preload event not received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Shutdown_StopsRoutersAndIsIdempotent(t *testing.T) {
|
||||
local := newStubRouter([]string{"local-model"}, "")
|
||||
peer := newStubRouter(nil, "")
|
||||
s := newTestServer(local, peer)
|
||||
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("Shutdown: %v", err)
|
||||
}
|
||||
if err := s.Shutdown(time.Second); err != nil {
|
||||
t.Fatalf("second Shutdown: %v", err)
|
||||
}
|
||||
if got := local.shutdownCalls.Load(); got != 1 {
|
||||
t.Errorf("local shutdownCalls=%d want 1", got)
|
||||
}
|
||||
if got := peer.shutdownCalls.Load(); got != 1 {
|
||||
t.Errorf("peer shutdownCalls=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_LogStream_ModelID(t *testing.T) {
|
||||
buf := logmon.NewWriter(io.Discard)
|
||||
buf.Write([]byte("hello from model"))
|
||||
|
||||
local := newStubRouter([]string{"mymodel"}, "")
|
||||
local.loggers = map[string]*logmon.Monitor{"mymodel": buf}
|
||||
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{Models: map[string]config.ModelConfig{"mymodel": {}}}
|
||||
|
||||
// Pre-cancel the context so the streaming loop exits immediately after
|
||||
// flushing history.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs/stream/mymodel", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
if got := w.Body.String(); got != "hello from model" {
|
||||
t.Errorf("body=%q want %q", got, "hello from model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_LogStream_UnknownID_Returns400(t *testing.T) {
|
||||
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs/stream/no-such-model", nil))
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d want 400", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// uiStaticFS holds the embedded UI build. The build is copied into ui_dist by
|
||||
// the Makefile's `ui` target; placeholder.txt keeps the embed valid before a
|
||||
// build has run.
|
||||
//
|
||||
//go:embed ui_dist
|
||||
var uiStaticFS embed.FS
|
||||
|
||||
// uiFS is the embedded UI rooted at ui_dist.
|
||||
var uiFS = func() http.FileSystem {
|
||||
sub, err := fs.Sub(uiStaticFS, "ui_dist")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return http.FS(sub)
|
||||
}()
|
||||
|
||||
// selectEncoding chooses the best pre-compressed encoding the client accepts.
|
||||
// It returns the encoding ("br" or "gzip") and the matching file extension.
|
||||
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||
if acceptEncoding == "" {
|
||||
return "", ""
|
||||
}
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
if strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) == "br" {
|
||||
return "br", ".br"
|
||||
}
|
||||
}
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
if strings.TrimSpace(strings.SplitN(part, ";", 2)[0]) == "gzip" {
|
||||
return "gzip", ".gz"
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// serveCompressedFile serves name from fsys, preferring a pre-compressed
|
||||
// sibling (name+".br" / name+".gz") when the client accepts it. It returns an
|
||||
// error without writing a response when name cannot be served, so callers can
|
||||
// fall back (e.g. SPA routing).
|
||||
func serveCompressedFile(fsys http.FileSystem, w http.ResponseWriter, r *http.Request, name string) error {
|
||||
if encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding")); encoding != "" {
|
||||
if cf, err := fsys.Open(name + ext); err == nil {
|
||||
defer cf.Close()
|
||||
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||
w.Header().Set("Content-Encoding", encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
file, err := fsys.Open(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stat.IsDir() {
|
||||
return fs.ErrNotExist
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleUI serves the embedded SPA under /ui/.
|
||||
func (s *Server) handleUI(w http.ResponseWriter, r *http.Request) {
|
||||
serveUI(uiFS, w, r)
|
||||
}
|
||||
|
||||
// serveUI serves the SPA from fsys. Real files are served with compression
|
||||
// support; unknown paths without a file extension fall back to index.html so
|
||||
// client-side routing works.
|
||||
func serveUI(fsys http.FileSystem, w http.ResponseWriter, r *http.Request) {
|
||||
name := strings.TrimPrefix(r.URL.Path, "/ui/")
|
||||
if name == "" {
|
||||
name = "index.html"
|
||||
}
|
||||
|
||||
if err := serveCompressedFile(fsys, w, r, name); err != nil {
|
||||
if strings.Contains(path.Base(name), ".") {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if err := serveCompressedFile(fsys, w, r, "index.html"); err != nil {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleFavicon serves /favicon.ico from the embedded UI build.
|
||||
func (s *Server) handleFavicon(w http.ResponseWriter, r *http.Request) {
|
||||
if err := serveCompressedFile(uiFS, w, r, "favicon.ico"); err != nil {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
placeholder so //go:embed ui_dist succeeds before the UI is built
|
||||
@@ -0,0 +1,92 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
)
|
||||
|
||||
func TestServer_SelectEncoding(t *testing.T) {
|
||||
cases := []struct {
|
||||
accept string
|
||||
encoding string
|
||||
ext string
|
||||
}{
|
||||
{"", "", ""},
|
||||
{"gzip", "gzip", ".gz"},
|
||||
{"gzip, deflate, br", "br", ".br"},
|
||||
{"deflate", "", ""},
|
||||
{"br;q=1.0, gzip;q=0.8", "br", ".br"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
enc, ext := selectEncoding(c.accept)
|
||||
if enc != c.encoding || ext != c.ext {
|
||||
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)", c.accept, enc, ext, c.encoding, c.ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func uiTestFS() http.FileSystem {
|
||||
return http.FS(fstest.MapFS{
|
||||
"index.html": {Data: []byte("<html>app</html>")},
|
||||
"app.js": {Data: []byte("plain")},
|
||||
"app.js.br": {Data: []byte("brotli")},
|
||||
"app.js.gz": {Data: []byte("gzipped")},
|
||||
"favicon.ico": {Data: []byte("icon")},
|
||||
})
|
||||
}
|
||||
|
||||
func serveUIRequest(t *testing.T, path, acceptEncoding string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
if acceptEncoding != "" {
|
||||
req.Header.Set("Accept-Encoding", acceptEncoding)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
serveUI(uiTestFS(), w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_File(t *testing.T) {
|
||||
w := serveUIRequest(t, "/ui/app.js", "")
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200", w.Code)
|
||||
}
|
||||
if w.Body.String() != "plain" {
|
||||
t.Errorf("body = %q, want plain", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_Brotli(t *testing.T) {
|
||||
w := serveUIRequest(t, "/ui/app.js", "gzip, br")
|
||||
if got := w.Header().Get("Content-Encoding"); got != "br" {
|
||||
t.Fatalf("Content-Encoding = %q, want br", got)
|
||||
}
|
||||
if w.Body.String() != "brotli" {
|
||||
t.Errorf("body = %q, want brotli", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_IndexAndRoot(t *testing.T) {
|
||||
for _, path := range []string{"/ui/", "/ui/index.html"} {
|
||||
w := serveUIRequest(t, path, "")
|
||||
if w.Code != http.StatusOK || w.Body.String() != "<html>app</html>" {
|
||||
t.Errorf("%s: status=%d body=%q", path, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_SPAFallback(t *testing.T) {
|
||||
w := serveUIRequest(t, "/ui/models", "")
|
||||
if w.Code != http.StatusOK || w.Body.String() != "<html>app</html>" {
|
||||
t.Errorf("SPA fallback: status=%d body=%q", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ServeUI_MissingFile(t *testing.T) {
|
||||
w := serveUIRequest(t, "/ui/missing.js", "")
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user