diff --git a/internal/router/base.go b/internal/router/base.go index 4b34cba5..7f7232ae 100644 --- a/internal/router/base.go +++ b/internal/router/base.go @@ -28,8 +28,7 @@ type unloadReq struct { // baseRouter owns the channels, run-loop, and process machinery shared by every // concrete router. Concrete routers embed *baseRouter and supply a -// scheduler.Factory (which captures their scheduler.Swapper) describing how -// requests are scheduled and how their eviction set is decided. baseRouter +// scheduler.Swapper describing how eviction sets are decided. baseRouter // implements scheduler.Effects so the scheduler can call back for side-effects. type baseRouter struct { name string @@ -75,8 +74,8 @@ func newBaseRouter( conf config.Config, processes map[string]process.Process, logger *logmon.Monitor, - newSched scheduler.Factory, -) *baseRouter { + planner scheduler.Swapper, +) (*baseRouter, error) { shutdownCtx, shutdownFn := context.WithCancel(context.Background()) procCtx, procCancel := context.WithCancel(context.Background()) b := &baseRouter{ @@ -96,8 +95,12 @@ func newBaseRouter( serveDoneCh: make(chan scheduler.ServeDoneEvent), runDone: make(chan struct{}), } - b.schedule = newSched(name, logger, b) - return b + sched, err := scheduler.New(conf, name, logger, planner, b) + if err != nil { + return nil, err + } + b.schedule = sched + return b, nil } func (b *baseRouter) notifyProcessed() { diff --git a/internal/router/base_test.go b/internal/router/base_test.go index 82e9e0ff..0c8d2fab 100644 --- a/internal/router/base_test.go +++ b/internal/router/base_test.go @@ -29,10 +29,10 @@ func (s *stubPlanner) OnSwapStart(string, []string) {} func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter { t.Helper() conf := config.Config{HealthCheckTimeout: 5} - b := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), - func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler 
{ - return scheduler.NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, eff) - }) + b, err := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), planner) + if err != nil { + t.Fatalf("newBaseRouter: %v", err) + } b.testProcessed = make(chan struct{}, 64) go b.run() t.Cleanup(func() { diff --git a/internal/router/group.go b/internal/router/group.go index 18491701..65cae63c 100644 --- a/internal/router/group.go +++ b/internal/router/group.go @@ -6,7 +6,6 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" - "github.com/mostlygeek/llama-swap/internal/router/scheduler" ) type Group struct { @@ -30,10 +29,10 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group } processes := make(map[string]process.Process, len(modelToGroup)) - base := newBaseRouter("group", conf, processes, proxylog, - func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { - return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) - }) + base, err := newBaseRouter("group", conf, processes, proxylog, swapper) + if err != nil { + return nil, fmt.Errorf("creating base router: %w", err) + } for mid := range modelToGroup { modelCfg, _, ok := conf.FindConfig(mid) diff --git a/internal/router/group_test.go b/internal/router/group_test.go index 336cbf28..caa28adf 100644 --- a/internal/router/group_test.go +++ b/internal/router/group_test.go @@ -10,7 +10,6 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" - "github.com/mostlygeek/llama-swap/internal/router/scheduler" ) // newTestGroup builds a Group directly from the supplied processes and config, @@ -27,10 +26,10 @@ func newTestGroup(t *testing.T, conf config.Config, processes map[string]process config: 
conf, modelToGroup: modelToGroup, } - base := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), - func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { - return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) - }) + base, err := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), swapper) + if err != nil { + t.Fatalf("newBaseRouter: %v", err) + } base.testProcessed = make(chan struct{}, 64) g := &Group{baseRouter: base} go base.run() diff --git a/internal/router/matrix.go b/internal/router/matrix.go index d3812446..231bb155 100644 --- a/internal/router/matrix.go +++ b/internal/router/matrix.go @@ -6,7 +6,6 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" - "github.com/mostlygeek/llama-swap/internal/router/scheduler" ) type Matrix struct { @@ -27,10 +26,10 @@ func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matr // Build a process for every model in the config. Any model can run alone // even if it is not part of a set; this mirrors proxy.NewMatrix. 
processes := make(map[string]process.Process, len(conf.Models)) - base := newBaseRouter("matrix", conf, processes, proxylog, - func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { - return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) - }) + base, err := newBaseRouter("matrix", conf, processes, proxylog, swapper) + if err != nil { + return nil, fmt.Errorf("creating base router: %w", err) + } for mid, modelCfg := range conf.Models { procLog := logmon.NewWriter(upstreamlog) diff --git a/internal/router/matrix_test.go b/internal/router/matrix_test.go index 0d7a985d..b093730f 100644 --- a/internal/router/matrix_test.go +++ b/internal/router/matrix_test.go @@ -10,7 +10,6 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" - "github.com/mostlygeek/llama-swap/internal/router/scheduler" ) // newTestMatrix builds a Matrix router from supplied processes, bypassing @@ -22,10 +21,10 @@ func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedS solver: newMatrixSolver(expanded, evictCosts), logger: logger, } - base := newBaseRouter("matrix", conf, processes, logger, - func(name string, l *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { - return scheduler.NewFIFO(name, l, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) - }) + base, err := newBaseRouter("matrix", conf, processes, logger, swapper) + if err != nil { + t.Fatalf("newBaseRouter: %v", err) + } base.testProcessed = make(chan struct{}, 64) r := &Matrix{baseRouter: base} go base.run() diff --git a/internal/router/scheduler/fifo.go b/internal/router/scheduler/fifo.go index 49e050b5..addd3b94 100644 --- a/internal/router/scheduler/fifo.go +++ b/internal/router/scheduler/fifo.go @@ -8,8 +8,13 @@ import ( "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" 
"github.com/mostlygeek/llama-swap/internal/process" + "github.com/mostlygeek/llama-swap/internal/shared" ) +// defaultConcurrencyLimit caps simultaneous in-flight requests per model when +// the model config leaves concurrencyLimit unset. +const defaultConcurrencyLimit = 10 + // activeSwap tracks one in-flight swap and the callers waiting on it. type activeSwap struct { modelID string @@ -33,20 +38,32 @@ type FIFO struct { cfg config.FifoConfig effects Effects + limits map[string]int active map[string]*activeSwap inFlight map[string]int queued []HandlerReq } -// NewFIFO builds a FIFO scheduler. It matches scheduler.Factory once a planner -// is captured in a closure. -func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, eff Effects) *FIFO { +// NewFIFO builds a FIFO scheduler. Per-model concurrency limits are derived +// from models: each model's ConcurrencyLimit overrides defaultConcurrencyLimit +// when set to a value greater than zero. +func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, models map[string]config.ModelConfig, eff Effects) *FIFO { + limits := make(map[string]int, len(models)) + for id, mc := range models { + limit := defaultConcurrencyLimit + if mc.ConcurrencyLimit > 0 { + limit = mc.ConcurrencyLimit + } + limits[id] = limit + } + return &FIFO{ name: name, logger: logger, planner: planner, cfg: cfg, effects: eff, + limits: limits, active: make(map[string]*activeSwap), inFlight: make(map[string]int), } @@ -254,12 +271,27 @@ func (s *FIFO) OnShutdown(err error) { // grantHandler hands the caller a tracked handler for modelID and, only if the // caller was still there to receive it, bumps the in-flight count. Incrementing // when the grant failed would strand the counter and block future evictions. +// Requests that would exceed the model's concurrency limit are rejected with a +// shared.ConcurrencyLimitError (HTTP 429 with Retry-After). 
func (s *FIFO) grantHandler(req HandlerReq, modelID string) { + if s.inFlight[modelID] >= s.limit(modelID) { + s.effects.GrantError(req, shared.ConcurrencyLimitError{}) + return + } if s.effects.GrantServe(req, modelID) { s.inFlight[modelID]++ } } +// limit returns the per-model concurrency cap, defaulting to +// defaultConcurrencyLimit when the model has no explicit entry. +func (s *FIFO) limit(modelID string) int { + if l, ok := s.limits[modelID]; ok { + return l + } + return defaultConcurrencyLimit +} + // startSwap records the swap as active and launches it via Effects. running is // the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against // the same picture it decided on. diff --git a/internal/router/scheduler/fifo_test.go b/internal/router/scheduler/fifo_test.go index f0a6e00d..6d177bb2 100644 --- a/internal/router/scheduler/fifo_test.go +++ b/internal/router/scheduler/fifo_test.go @@ -3,12 +3,14 @@ package scheduler import ( "errors" "io" + "net/http" "testing" "time" "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" + "github.com/mostlygeek/llama-swap/internal/shared" ) // FIFO methods all run on the router's single run-loop goroutine, so these @@ -138,7 +140,7 @@ func (f *fakeEffects) startsFor(modelID string) int { } func newFIFO(planner Swapper, eff Effects) *FIFO { - return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, eff) + return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, nil, eff) } func req(model string) HandlerReq { return HandlerReq{Model: model} } @@ -521,7 +523,7 @@ func TestFIFO_PriorityQueueOrder(t *testing.T) { // loading collides with z's in-flight swap and parks in the queue. 
planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}} cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}} - s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, eff) + s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, nil, eff) s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D]) @@ -631,3 +633,123 @@ func TestFIFO_OnCancel_NotPresent(t *testing.T) { t.Errorf("queue should be empty, len=%d", len(s.queued)) } } + +// newFIFOWithLimit builds a FIFO whose single model has the given concurrency +// limit, already in StateReady so every request exercises the fast path. +func newFIFOWithLimit(t *testing.T, model string, limit int) (*FIFO, *fakeEffects) { + t.Helper() + eff := newFakeEffects() + eff.states[model] = process.StateReady + models := map[string]config.ModelConfig{ + model: {ConcurrencyLimit: limit}, + } + s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff) + return s, eff +} + +// TestFIFO_ConcurrencyLimit_RejectsOverLimit verifies that a request arriving +// while the model is at capacity gets an error grant instead of being served, +// and that a new request succeeds once an in-flight one completes. +func TestFIFO_ConcurrencyLimit_RejectsOverLimit(t *testing.T) { + s, eff := newFIFOWithLimit(t, "a", 1) + + // First request: served (inFlight 0 → 1). + s.OnRequest(req("a")) + if got := eff.served("a"); got != 1 { + t.Fatalf("served(a)=%d want 1", got) + } + + // Second request while slot is occupied: rejected with HTTPError 429. 
+ s.OnRequest(req("a")) + if got := eff.errored("a"); got != 1 { + t.Fatalf("errored(a)=%d want 1 (over-limit)", got) + } + var httpErr shared.HTTPError + if !errors.As(eff.grants[len(eff.grants)-1].err, &httpErr) { + t.Fatalf("err=%v want HTTPError", eff.grants[len(eff.grants)-1].err) + } + if httpErr.StatusCode() != http.StatusTooManyRequests { + t.Fatalf("StatusCode()=%d want 429", httpErr.StatusCode()) + } + if httpErr.Header().Get("Retry-After") == "" { + t.Fatal("missing Retry-After header") + } + + // After the in-flight request finishes, a new request succeeds. + s.OnServeDone(ServeDoneEvent{ModelID: "a"}) + s.OnRequest(req("a")) + if got := eff.served("a"); got != 2 { + t.Fatalf("served(a)=%d want 2 after drain", got) + } +} + +// TestFIFO_ConcurrencyLimit_DefaultIsTen verifies that a model without an +// explicit ConcurrencyLimit gets the default cap of 10. +func TestFIFO_ConcurrencyLimit_DefaultIsTen(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateReady + // nil models → every model gets defaultConcurrencyLimit (10). + s := newFIFO(&stubPlanner{}, eff) + + for i := 0; i < 10; i++ { + s.OnRequest(req("a")) + } + if got := eff.served("a"); got != 10 { + t.Fatalf("served(a)=%d want 10 (default limit)", got) + } + + // 11th request is rejected. + s.OnRequest(req("a")) + if got := eff.errored("a"); got != 1 { + t.Fatalf("errored(a)=%d want 1 (over default limit)", got) + } +} + +// TestFIFO_ConcurrencyLimit_CustomLimit verifies a ConcurrencyLimit greater +// than zero overrides the default. 
+func TestFIFO_ConcurrencyLimit_CustomLimit(t *testing.T) { + s, eff := newFIFOWithLimit(t, "a", 2) + + s.OnRequest(req("a")) + s.OnRequest(req("a")) + s.OnRequest(req("a")) + + if got := eff.served("a"); got != 2 { + t.Fatalf("served(a)=%d want 2 (custom limit)", got) + } + if got := eff.errored("a"); got != 1 { + t.Fatalf("errored(a)=%d want 1 (over custom limit)", got) + } +} + +// TestFIFO_ConcurrencyLimit_SwapWaiters verifies that when more swap waiters +// exist than the concurrency limit, excess waiters are rejected on swap +// completion rather than exceeding the limit. +func TestFIFO_ConcurrencyLimit_SwapWaiters(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateStopped + models := map[string]config.ModelConfig{ + "a": {ConcurrencyLimit: 2}, + } + s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff) + + // Three requests arrive while model is loading: one starts swap, two join. + s.OnRequest(req("a")) + s.OnRequest(req("a")) + s.OnRequest(req("a")) + + if got := eff.startsFor("a"); got != 1 { + t.Fatalf("StartSwap(a)=%d want 1", got) + } + + // Swap completes: two served (limit), one rejected. 
+ eff.states["a"] = process.StateReady + s.OnSwapDone(SwapDone{ModelID: "a"}) + + if got := eff.served("a"); got != 2 { + t.Fatalf("served(a)=%d want 2 (limit on swap completion)", got) + } + if got := eff.errored("a"); got != 1 { + t.Fatalf("errored(a)=%d want 1 (excess waiter rejected)", got) + } +} diff --git a/internal/router/scheduler/scheduler.go b/internal/router/scheduler/scheduler.go index eda36574..8b87db3b 100644 --- a/internal/router/scheduler/scheduler.go +++ b/internal/router/scheduler/scheduler.go @@ -11,9 +11,11 @@ package scheduler import ( "context" + "fmt" "net/http" "time" + "github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/shared" @@ -90,9 +92,21 @@ type Effects interface { StopProcesses(timeout time.Duration, ids []string) } -// Factory builds a Scheduler bound to a baseRouter's Effects. The concrete -// router captures its Swapper in the closure it passes as a Factory. -type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler +// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured +// from conf and bound to the given planner and effects. Currently only "fifo" +// (the default) is supported. +func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) { + use := conf.Routing.Scheduler.Use + if use == "" { + use = "fifo" + } + switch use { + case "fifo": + return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil + default: + return nil, fmt.Errorf("unsupported scheduler type: %q", use) + } +} // HandlerReq is one in-flight ServeHTTP request waiting for a routing decision. 
type HandlerReq struct { diff --git a/internal/server/concurrency.go b/internal/server/concurrency.go deleted file mode 100644 index ccc339f3..00000000 --- a/internal/server/concurrency.go +++ /dev/null @@ -1,57 +0,0 @@ -package server - -import ( - "net/http" - - "golang.org/x/sync/semaphore" - - "github.com/mostlygeek/llama-swap/internal/chain" - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/shared" -) - -// defaultConcurrencyLimit caps simultaneous in-flight requests per model when -// the model config leaves concurrencyLimit unset. Matches the legacy -// proxy.Process default. -const defaultConcurrencyLimit = 10 - -// CreateConcurrencyMiddleware returns middleware that limits simultaneous -// model-dispatched requests per model. Each model gets a semaphore sized to -// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot -// immediately acquire a slot is rejected with 429. Models without a local -// config entry (e.g. peer-routed models) are not limited. 
-func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware { - semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models)) - for id, mc := range cfg.Models { - limit := defaultConcurrencyLimit - if mc.ConcurrencyLimit > 0 { - limit = mc.ConcurrencyLimit - } - semaphores[id] = semaphore.NewWeighted(int64(limit)) - } - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - data, err := shared.FetchContext(r, cfg) - if err != nil { - shared.SendError(w, r, shared.ErrNoModelInContext) - return - } - - // fall through for peer models - sem, ok := semaphores[data.ModelID] - if !ok { - next.ServeHTTP(w, r) - return - } - if !sem.TryAcquire(1) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - w.Write([]byte(`{"error":"Too many requests"}`)) - return - } - defer sem.Release(1) - next.ServeHTTP(w, r) - }) - } -} diff --git a/internal/server/concurrency_test.go b/internal/server/concurrency_test.go deleted file mode 100644 index 9cc68f97..00000000 --- a/internal/server/concurrency_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package server - -import ( - "net/http" - "net/http/httptest" - "sync" - "testing" - - "github.com/mostlygeek/llama-swap/internal/config" - "github.com/mostlygeek/llama-swap/internal/shared" -) - -func concurrencyTestReq(model string) *http.Request { - r := httptest.NewRequest("GET", "/v1/chat/completions", nil) - return r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: model, ModelID: model})) -} - -func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) { - cfg := config.Config{ - Models: map[string]config.ModelConfig{ - "m1": {ConcurrencyLimit: 1}, - }, - } - - entered := make(chan struct{}) - release := make(chan struct{}) - var once sync.Once - final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - once.Do(func() { close(entered) }) - <-release - 
w.WriteHeader(http.StatusOK) - }) - h := CreateConcurrencyMiddleware(cfg)(final) - - // First request occupies the only slot. - done := make(chan struct{}) - go func() { - defer close(done) - h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1")) - }() - <-entered - - // Second concurrent request is rejected with 429. - w := httptest.NewRecorder() - h.ServeHTTP(w, concurrencyTestReq("m1")) - if w.Code != http.StatusTooManyRequests { - t.Fatalf("over-limit status = %d, want 429", w.Code) - } - - // Once the slot frees, a new request succeeds. - close(release) - <-done - w = httptest.NewRecorder() - h.ServeHTTP(w, concurrencyTestReq("m1")) - if w.Code != http.StatusOK { - t.Fatalf("post-release status = %d, want 200", w.Code) - } -} - -func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) { - cfg := config.Config{Models: map[string]config.ModelConfig{}} - - called := 0 - final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called++ - w.WriteHeader(http.StatusOK) - }) - h := CreateConcurrencyMiddleware(cfg)(final) - - w := httptest.NewRecorder() - h.ServeHTTP(w, concurrencyTestReq("peer-model")) - if w.Code != http.StatusOK || called != 1 { - t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called) - } -} diff --git a/internal/server/server.go b/internal/server/server.go index f2ad15ab..739e0c0d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -177,7 +177,6 @@ func (s *Server) routes() { modelChain := chain.New( authMW, CreateRequestContextMiddleware(s.cfg), - CreateConcurrencyMiddleware(s.cfg), CreateFilterMiddleware(s.cfg), CreateFormFilterMiddleware(s.cfg), CreateInflightMiddleware(s.inflight), diff --git a/internal/shared/http.go b/internal/shared/http.go index adc19e87..a0cb063b 100644 --- a/internal/shared/http.go +++ b/internal/shared/http.go @@ -37,6 +37,16 @@ var ( ) func SendError(w http.ResponseWriter, r *http.Request, err error) { + var httpErr 
HTTPError + if errors.As(err, &httpErr) { + for k, v := range httpErr.Header() { + w.Header()[k] = v + } + w.WriteHeader(httpErr.StatusCode()) + w.Write(httpErr.Body()) + return + } + switch { case errors.Is(err, ErrNoModelInContext): SendResponse(w, r, http.StatusNotFound, "no model id could be identified") diff --git a/internal/shared/httperror.go b/internal/shared/httperror.go new file mode 100644 index 00000000..e46a908b --- /dev/null +++ b/internal/shared/httperror.go @@ -0,0 +1,63 @@ +package shared + +import ( + "encoding/json" + "net/http" + "strconv" +) + +// HTTPError is an error that carries a complete HTTP response. A producer (e.g. +// a scheduler shedding a request) returns one of these; a renderer (e.g. +// SendError) writes the status, headers, and body verbatim instead of +// mapping the error to a generic status. It is the seam that lets a component +// shed a request with a rich response (e.g. a 429 with rate-limit headers and a +// JSON hint body) without the renderer knowing the producer's internals. +type HTTPError interface { + error + StatusCode() int + Header() http.Header + Body() []byte +} + +// ConcurrencyLimitError is an HTTPError for a 429 concurrency-limit rejection. +// Zero-value fields fall back to sensible defaults: a 1-second Retry-After and a +// JSON hint body. +type ConcurrencyLimitError struct { + // RetryAfter, when > 0, is sent as the Retry-After header (in seconds). + // Defaults to 1. + RetryAfter int + + // Message overrides the JSON body's "error" field. Defaults to + // "Too many requests". 
+ Message string +} + +func (e ConcurrencyLimitError) Error() string { return "concurrency limit reached" } + +func (e ConcurrencyLimitError) StatusCode() int { return http.StatusTooManyRequests } + +func (e ConcurrencyLimitError) Header() http.Header { + h := http.Header{} + h.Set("Content-Type", "application/json") + h.Set("Retry-After", e.retryAfter()) + return h +} + +func (e ConcurrencyLimitError) Body() []byte { + b, _ := json.Marshal(map[string]string{"error": e.message()}) + return b +} + +func (e ConcurrencyLimitError) retryAfter() string { + if e.RetryAfter > 0 { + return strconv.Itoa(e.RetryAfter) + } + return "1" +} + +func (e ConcurrencyLimitError) message() string { + if e.Message != "" { + return e.Message + } + return "Too many requests" +}