package server import ( "context" "errors" "fmt" "net/http" "strings" "sync" "sync/atomic" "time" "github.com/mostlygeek/llama-swap/internal/chain" "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/perf" "github.com/mostlygeek/llama-swap/internal/router" ) // Server owns the HTTP mux, cross-cutting middleware, and the local/peer model // dispatch. It supersedes router.Server: it builds the local and peer routers // directly and dispatches between them itself. type Server struct { cfg config.Config muxlog *logmon.Monitor proxylog *logmon.Monitor upstreamlog *logmon.Monitor perf *perf.Monitor inflight *inflightCounter metrics *metricsMonitor build BuildInfo local router.LocalRouter peer router.Router mux *http.ServeMux handler http.Handler shutdownCtx context.Context shutdownFn context.CancelFunc shuttingDown atomic.Bool } // modelPostJSONRoutes are endpoints with a model id in the JSON request body. var modelPostJSONRoutes = []string{ "/v1/chat/completions", "/v1/responses", "/v1/completions", "/v1/messages", "/v1/messages/count_tokens", "/v1/embeddings", "/reranking", "/rerank", "/v1/rerank", "/v1/reranking", "/infill", "/completion", "/v1/audio/speech", "/v1/audio/voices", "/v1/images/generations", "/sdapi/v1/txt2img", "/sdapi/v1/img2img", // versionless routes, the /v/ is stripped before the request is forwarded upstream // see issue #728 "/v/chat/completions", "/v/responses", "/v/completions", "/v/messages", "/v/messages/count_tokens", "/v/embeddings", "/v/rerank", "/v/reranking", } // modelPostFormRoutes are multipart/form-data endpoints with a model id in the form data var modelPostFormRoutes = []string{ "/v1/audio/transcriptions", "/v1/images/edits", } // modelGetRoutes are model-dispatched GET endpoints (the model arrives as a // query parameter). var modelGetRoutes = []string{ "/v1/audio/voices", "/sdapi/v1/loras", } // BuildInfo carries version metadata surfaced by GET /api/version. type BuildInfo struct { Version string Commit string Date string } func New(cfg config.Config, muxlog *logmon.Monitor, proxylog *logmon.Monitor, upstreamlog *logmon.Monitor, perfMon *perf.Monitor, build BuildInfo) (*Server, error) { var local router.LocalRouter var err error switch cfg.Routing.Router.Use { case "matrix": local, err = router.NewMatrix(cfg, proxylog, upstreamlog) if err != nil { return nil, fmt.Errorf("creating matrix router: %w", err) } default: // "group" local, err = router.NewGroup(cfg, proxylog, upstreamlog) if err != nil { return nil, fmt.Errorf("creating group router: %w", err) } } peer, err := router.NewPeer(cfg, proxylog) if err != nil { return nil, fmt.Errorf("creating peer router: %w", err) } shutdownCtx, shutdownFn := context.WithCancel(context.Background()) s := &Server{ cfg: cfg, muxlog: muxlog, proxylog: proxylog, upstreamlog: upstreamlog, perf: perfMon, inflight: &inflightCounter{}, metrics: newMetricsMonitor(proxylog, cfg.MetricsMaxInMemory, cfg.CaptureBuffer), build: build, local: local, peer: peer, shutdownCtx: shutdownCtx, shutdownFn: shutdownFn, } s.routes() s.startPreload() return s, nil } // localPeerHandler dispatches a model-routed request to the local or peer // router. The model is resolved once via router.FetchContext. func (s *Server) localPeerHandler(w http.ResponseWriter, r *http.Request) { stripVersionPrefix(r) data, err := router.FetchContext(r, s.cfg) if err != nil { router.SendError(w, r, router.ErrNoModelInContext) return } switch { case s.local.Handles(data.ModelID): s.proxylog.Debugf("dispatch: using local process for model: %s", data.ModelID) s.local.ServeHTTP(w, r) case s.peer.Handles(data.ModelID): s.proxylog.Debugf("dispatch: using peer for model: %s", data.ModelID) s.peer.ServeHTTP(w, r) default: router.SendError(w, r, router.ErrNoRouterFound) } } // stripVersionPrefix rewrites versionless /v/... requests to their /... form // before forwarding upstream (issue #728). func stripVersionPrefix(r *http.Request) { if strings.HasPrefix(r.URL.Path, "/v/") { r.URL.Path = strings.TrimPrefix(r.URL.Path, "/v") } } // routes builds the mux, registers every route, and wraps the mux with the // global CORS middleware. func (s *Server) routes() { authMW := CreateAuthMiddleware(s.cfg) filterMW := CreateFilterMiddleware(s.cfg) formFilterMW := CreateFormFilterMiddleware(s.cfg) // Model-dispatched routes get auth + per-model concurrency limiting + body // filters + in-flight tracking + token metrics. concurrencyMW rejects with // 429 before the body filters do any rewrite work. filterMW rewrites JSON // bodies and formFilterMW rewrites multipart bodies; each is a no-op for the // other's Content-Type. Both run before the metrics middleware so it buffers // the rewritten body. modelChain := chain.New( authMW, CreateConcurrencyMiddleware(s.cfg), filterMW, formFilterMW, CreateInflightMiddleware(s.inflight), CreateMetricsMiddleware(s.metrics, s.cfg), ) // Custom endpoints only need auth. apiChain := chain.New(authMW) mux := http.NewServeMux() dispatch := http.HandlerFunc(s.localPeerHandler) for _, path := range modelPostJSONRoutes { mux.Handle("POST "+path, modelChain.Then(dispatch)) } for _, path := range modelPostFormRoutes { mux.Handle("POST "+path, modelChain.Then(dispatch)) } for _, path := range modelGetRoutes { mux.Handle("GET "+path, modelChain.Then(dispatch)) } // llama-swap API + custom endpoints. mux.Handle("GET /v1/models", apiChain.ThenFunc(s.handleListModels)) mux.Handle("GET /logs", apiChain.ThenFunc(s.handleLogs)) mux.Handle("GET /logs/stream", apiChain.ThenFunc(s.handleLogStream)) mux.Handle("GET /logs/stream/{logMonitorID...}", apiChain.ThenFunc(s.handleLogStream)) mux.HandleFunc("GET /health", handleHealth) mux.HandleFunc("GET /wol-health", handleHealth) mux.HandleFunc("GET /{$}", handleRootRedirect) // Embedded UI. mux.HandleFunc("GET /ui/", s.handleUI) mux.HandleFunc("GET /favicon.ico", s.handleFavicon) // Prometheus metrics (no auth, matches the legacy endpoint). mux.HandleFunc("GET /metrics", s.handleMetrics) // Operations endpoints. mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload)) mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning)) // Upstream passthrough. mux.HandleFunc("GET /upstream", handleUpstreamRedirect) mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream)) // API group (API-key protected) consumed by the UI. mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll)) mux.Handle("POST /api/models/unload/{model...}", apiChain.ThenFunc(s.handleAPIUnloadModel)) mux.Handle("GET /api/events", apiChain.ThenFunc(s.handleAPIEvents)) mux.Handle("GET /api/metrics", apiChain.ThenFunc(s.handleAPIMetrics)) mux.Handle("GET /api/performance", apiChain.ThenFunc(s.handleAPIPerformance)) mux.Handle("GET /api/version", apiChain.ThenFunc(s.handleAPIVersion)) mux.Handle("GET /api/captures/{id}", apiChain.ThenFunc(s.handleAPICapture)) s.mux = mux s.handler = chain.New(CreateRequestLogMiddleware(s.proxylog), CreateCORSMiddleware()).Then(mux) } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.handler.ServeHTTP(w, r) } // CloseStreams cancels long-lived response streams (Server-Sent Events) so a // graceful httpServer.Shutdown can drain without blocking on them. It does not // tear down routers; call Shutdown for that. Safe to call repeatedly. func (s *Server) CloseStreams() { s.shutdownFn() } // Shutdown stops the local and peer routers in parallel. It is idempotent; // repeated calls return nil without re-running shutdown. // // Callers must drain inflight HTTP requests (httpServer.Shutdown) before // calling this, otherwise inflight requests 502 when their processes are torn // down. Call CloseStreams before httpServer.Shutdown so SSE streams do not // block the drain. func (s *Server) Shutdown(timeout time.Duration) error { if !s.shuttingDown.CompareAndSwap(false, true) { return nil } s.shutdownFn() var wg sync.WaitGroup var mu sync.Mutex var errs []error for _, rt := range []router.Router{s.local, s.peer} { if rt == nil { continue } wg.Add(1) go func(rt router.Router) { defer wg.Done() if err := rt.Shutdown(timeout); err != nil { mu.Lock() errs = append(errs, err) mu.Unlock() } }(rt) } wg.Wait() return errors.Join(errs...) }