schedule,shared: move concurrency 429 limits into scheduler code (#849)

- make concurrency limiting the scheduler.Scheduler's responsibility
- eliminate the separate concurrency limit middleware 
- move concurrencyLimit logic into scheduler.FIFO to maintain backwards compatibility
- add HTTPError from #834 

Updates #834
This commit is contained in:
Benson Wong
2026-06-15 22:35:12 -07:00
committed by GitHub
parent 8e84b2ec4f
commit 6cf1317341
14 changed files with 278 additions and 171 deletions
+9 -6
View File
@@ -28,8 +28,7 @@ type unloadReq struct {
// baseRouter owns the channels, run-loop, and process machinery shared by every // baseRouter owns the channels, run-loop, and process machinery shared by every
// concrete router. Concrete routers embed *baseRouter and supply a // concrete router. Concrete routers embed *baseRouter and supply a
// scheduler.Factory (which captures their scheduler.Swapper) describing how // scheduler.Swapper describing how eviction sets are decided. baseRouter
// requests are scheduled and how their eviction set is decided. baseRouter
// implements scheduler.Effects so the scheduler can call back for side-effects. // implements scheduler.Effects so the scheduler can call back for side-effects.
type baseRouter struct { type baseRouter struct {
name string name string
@@ -75,8 +74,8 @@ func newBaseRouter(
conf config.Config, conf config.Config,
processes map[string]process.Process, processes map[string]process.Process,
logger *logmon.Monitor, logger *logmon.Monitor,
newSched scheduler.Factory, planner scheduler.Swapper,
) *baseRouter { ) (*baseRouter, error) {
shutdownCtx, shutdownFn := context.WithCancel(context.Background()) shutdownCtx, shutdownFn := context.WithCancel(context.Background())
procCtx, procCancel := context.WithCancel(context.Background()) procCtx, procCancel := context.WithCancel(context.Background())
b := &baseRouter{ b := &baseRouter{
@@ -96,8 +95,12 @@ func newBaseRouter(
serveDoneCh: make(chan scheduler.ServeDoneEvent), serveDoneCh: make(chan scheduler.ServeDoneEvent),
runDone: make(chan struct{}), runDone: make(chan struct{}),
} }
b.schedule = newSched(name, logger, b) sched, err := scheduler.New(conf, name, logger, planner, b)
return b if err != nil {
return nil, err
}
b.schedule = sched
return b, nil
} }
func (b *baseRouter) notifyProcessed() { func (b *baseRouter) notifyProcessed() {
+4 -4
View File
@@ -29,10 +29,10 @@ func (s *stubPlanner) OnSwapStart(string, []string) {}
func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter { func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter {
t.Helper() t.Helper()
conf := config.Config{HealthCheckTimeout: 5} conf := config.Config{HealthCheckTimeout: 5}
b := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), b, err := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), planner)
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, eff) t.Fatalf("newBaseRouter: %v", err)
}) }
b.testProcessed = make(chan struct{}, 64) b.testProcessed = make(chan struct{}, 64)
go b.run() go b.run()
t.Cleanup(func() { t.Cleanup(func() {
+4 -5
View File
@@ -6,7 +6,6 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
) )
type Group struct { type Group struct {
@@ -30,10 +29,10 @@ func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group
} }
processes := make(map[string]process.Process, len(modelToGroup)) processes := make(map[string]process.Process, len(modelToGroup))
base := newBaseRouter("group", conf, processes, proxylog, base, err := newBaseRouter("group", conf, processes, proxylog, swapper)
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) return nil, fmt.Errorf("creating base router: %w", err)
}) }
for mid := range modelToGroup { for mid := range modelToGroup {
modelCfg, _, ok := conf.FindConfig(mid) modelCfg, _, ok := conf.FindConfig(mid)
+4 -5
View File
@@ -10,7 +10,6 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
) )
// newTestGroup builds a Group directly from the supplied processes and config, // newTestGroup builds a Group directly from the supplied processes and config,
@@ -27,10 +26,10 @@ func newTestGroup(t *testing.T, conf config.Config, processes map[string]process
config: conf, config: conf,
modelToGroup: modelToGroup, modelToGroup: modelToGroup,
} }
base := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), base, err := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), swapper)
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) t.Fatalf("newBaseRouter: %v", err)
}) }
base.testProcessed = make(chan struct{}, 64) base.testProcessed = make(chan struct{}, 64)
g := &Group{baseRouter: base} g := &Group{baseRouter: base}
go base.run() go base.run()
+4 -5
View File
@@ -6,7 +6,6 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
) )
type Matrix struct { type Matrix struct {
@@ -27,10 +26,10 @@ func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matr
// Build a process for every model in the config. Any model can run alone // Build a process for every model in the config. Any model can run alone
// even if it is not part of a set; this mirrors proxy.NewMatrix. // even if it is not part of a set; this mirrors proxy.NewMatrix.
processes := make(map[string]process.Process, len(conf.Models)) processes := make(map[string]process.Process, len(conf.Models))
base := newBaseRouter("matrix", conf, processes, proxylog, base, err := newBaseRouter("matrix", conf, processes, proxylog, swapper)
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, logger, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) return nil, fmt.Errorf("creating base router: %w", err)
}) }
for mid, modelCfg := range conf.Models { for mid, modelCfg := range conf.Models {
procLog := logmon.NewWriter(upstreamlog) procLog := logmon.NewWriter(upstreamlog)
+4 -5
View File
@@ -10,7 +10,6 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
) )
// newTestMatrix builds a Matrix router from supplied processes, bypassing // newTestMatrix builds a Matrix router from supplied processes, bypassing
@@ -22,10 +21,10 @@ func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedS
solver: newMatrixSolver(expanded, evictCosts), solver: newMatrixSolver(expanded, evictCosts),
logger: logger, logger: logger,
} }
base := newBaseRouter("matrix", conf, processes, logger, base, err := newBaseRouter("matrix", conf, processes, logger, swapper)
func(name string, l *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler { if err != nil {
return scheduler.NewFIFO(name, l, swapper, conf.Routing.Scheduler.Settings.Fifo, eff) t.Fatalf("newBaseRouter: %v", err)
}) }
base.testProcessed = make(chan struct{}, 64) base.testProcessed = make(chan struct{}, 64)
r := &Matrix{baseRouter: base} r := &Matrix{baseRouter: base}
go base.run() go base.run()
+35 -3
View File
@@ -8,8 +8,13 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/shared"
) )
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
// the model config leaves concurrencyLimit unset.
const defaultConcurrencyLimit = 10
// activeSwap tracks one in-flight swap and the callers waiting on it. // activeSwap tracks one in-flight swap and the callers waiting on it.
type activeSwap struct { type activeSwap struct {
modelID string modelID string
@@ -33,20 +38,32 @@ type FIFO struct {
cfg config.FifoConfig cfg config.FifoConfig
effects Effects effects Effects
limits map[string]int
active map[string]*activeSwap active map[string]*activeSwap
inFlight map[string]int inFlight map[string]int
queued []HandlerReq queued []HandlerReq
} }
// NewFIFO builds a FIFO scheduler. It matches scheduler.Factory once a planner // NewFIFO builds a FIFO scheduler. Per-model concurrency limits are derived
// is captured in a closure. // from models: each model's ConcurrencyLimit overrides defaultConcurrencyLimit
func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, eff Effects) *FIFO { // when set to a value greater than zero.
func NewFIFO(name string, logger *logmon.Monitor, planner Swapper, cfg config.FifoConfig, models map[string]config.ModelConfig, eff Effects) *FIFO {
limits := make(map[string]int, len(models))
for id, mc := range models {
limit := defaultConcurrencyLimit
if mc.ConcurrencyLimit > 0 {
limit = mc.ConcurrencyLimit
}
limits[id] = limit
}
return &FIFO{ return &FIFO{
name: name, name: name,
logger: logger, logger: logger,
planner: planner, planner: planner,
cfg: cfg, cfg: cfg,
effects: eff, effects: eff,
limits: limits,
active: make(map[string]*activeSwap), active: make(map[string]*activeSwap),
inFlight: make(map[string]int), inFlight: make(map[string]int),
} }
@@ -254,12 +271,27 @@ func (s *FIFO) OnShutdown(err error) {
// grantHandler hands the caller a tracked handler for modelID and, only if the // grantHandler hands the caller a tracked handler for modelID and, only if the
// caller was still there to receive it, bumps the in-flight count. Incrementing // caller was still there to receive it, bumps the in-flight count. Incrementing
// when the grant failed would strand the counter and block future evictions. // when the grant failed would strand the counter and block future evictions.
// Requests that would exceed the model's concurrency limit are rejected with a
// shared.ConcurrencyLimitError (HTTP 429 with Retry-After).
func (s *FIFO) grantHandler(req HandlerReq, modelID string) { func (s *FIFO) grantHandler(req HandlerReq, modelID string) {
if s.inFlight[modelID] >= s.limit(modelID) {
s.effects.GrantError(req, shared.ConcurrencyLimitError{})
return
}
if s.effects.GrantServe(req, modelID) { if s.effects.GrantServe(req, modelID) {
s.inFlight[modelID]++ s.inFlight[modelID]++
} }
} }
// limit reports how many requests may be in flight at once for modelID. A
// model that has no entry in s.limits is capped at defaultConcurrencyLimit.
func (s *FIFO) limit(modelID string) int {
	configured, found := s.limits[modelID]
	if !found {
		return defaultConcurrencyLimit
	}
	return configured
}
// startSwap records the swap as active and launches it via Effects. running is // startSwap records the swap as active and launches it via Effects. running is
// the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against // the set EvictionFor saw, forwarded to OnSwapStart so the planner logs against
// the same picture it decided on. // the same picture it decided on.
+124 -2
View File
@@ -3,12 +3,14 @@ package scheduler
import ( import (
"errors" "errors"
"io" "io"
"net/http"
"testing" "testing"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/shared"
) )
// FIFO methods all run on the router's single run-loop goroutine, so these // FIFO methods all run on the router's single run-loop goroutine, so these
@@ -138,7 +140,7 @@ func (f *fakeEffects) startsFor(modelID string) int {
} }
func newFIFO(planner Swapper, eff Effects) *FIFO { func newFIFO(planner Swapper, eff Effects) *FIFO {
return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, eff) return NewFIFO("test", logmon.NewWriter(io.Discard), planner, config.FifoConfig{}, nil, eff)
} }
func req(model string) HandlerReq { return HandlerReq{Model: model} } func req(model string) HandlerReq { return HandlerReq{Model: model} }
@@ -521,7 +523,7 @@ func TestFIFO_PriorityQueueOrder(t *testing.T) {
// loading collides with z's in-flight swap and parks in the queue. // loading collides with z's in-flight swap and parks in the queue.
planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}} planner := &stubPlanner{evict: map[string][]string{"z": {"A", "B", "C", "D"}}}
cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}} cfg := config.FifoConfig{Priority: map[string]int{"A": 10, "B": 5, "C": 5, "D": 1}}
s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, eff) s := NewFIFO("test", logmon.NewWriter(io.Discard), planner, cfg, nil, eff)
s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D]) s.OnRequest(req("z")) // StartSwap(z, [A,B,C,D])
@@ -631,3 +633,123 @@ func TestFIFO_OnCancel_NotPresent(t *testing.T) {
t.Errorf("queue should be empty, len=%d", len(s.queued)) t.Errorf("queue should be empty, len=%d", len(s.queued))
} }
} }
// newFIFOWithLimit returns a FIFO that knows exactly one model, configured with
// the given ConcurrencyLimit, together with the fakeEffects it is bound to. The
// model is marked StateReady before the scheduler is built so that requests do
// not have to start a swap first.
func newFIFOWithLimit(t *testing.T, model string, limit int) (*FIFO, *fakeEffects) {
	t.Helper()

	eff := newFakeEffects()
	eff.states[model] = process.StateReady

	models := make(map[string]config.ModelConfig, 1)
	models[model] = config.ModelConfig{ConcurrencyLimit: limit}

	return NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff), eff
}
// TestFIFO_ConcurrencyLimit_RejectsOverLimit covers the limit-of-one case: the
// first request is served, a second one arriving while the slot is taken gets
// an error grant that is a 429 HTTPError with a Retry-After header, and a
// request made after the first one finishes is served again.
func TestFIFO_ConcurrencyLimit_RejectsOverLimit(t *testing.T) {
	s, eff := newFIFOWithLimit(t, "a", 1)

	// The only slot is free, so this request is served (inFlight 0 → 1).
	s.OnRequest(req("a"))
	if got := eff.served("a"); got != 1 {
		t.Fatalf("served(a)=%d want 1", got)
	}

	// The slot is taken now; the next request must come back as an error.
	s.OnRequest(req("a"))
	if got := eff.errored("a"); got != 1 {
		t.Fatalf("errored(a)=%d want 1 (over-limit)", got)
	}

	// The most recent grant is the rejection: it has to be a 429 HTTPError
	// that tells the client when to retry.
	rejection := eff.grants[len(eff.grants)-1].err
	var httpErr shared.HTTPError
	if !errors.As(rejection, &httpErr) {
		t.Fatalf("err=%v want HTTPError", rejection)
	}
	if code := httpErr.StatusCode(); code != http.StatusTooManyRequests {
		t.Fatalf("StatusCode()=%d want 429", code)
	}
	if retry := httpErr.Header().Get("Retry-After"); retry == "" {
		t.Fatal("missing Retry-After header")
	}

	// Finishing the in-flight request frees the slot for a new one.
	s.OnServeDone(ServeDoneEvent{ModelID: "a"})
	s.OnRequest(req("a"))
	if got := eff.served("a"); got != 2 {
		t.Fatalf("served(a)=%d want 2 after drain", got)
	}
}
// TestFIFO_ConcurrencyLimit_DefaultIsTen pins the fallback cap: a model with no
// explicit ConcurrencyLimit is served ten concurrent requests and the eleventh
// is rejected.
func TestFIFO_ConcurrencyLimit_DefaultIsTen(t *testing.T) {
	eff := newFakeEffects()
	eff.states["a"] = process.StateReady
	// newFIFO passes a nil models map, so "a" has no explicit limit entry and
	// falls back to defaultConcurrencyLimit (10).
	s := newFIFO(&stubPlanner{}, eff)

	// Fill every slot the default allows.
	for remaining := 10; remaining > 0; remaining-- {
		s.OnRequest(req("a"))
	}
	if got := eff.served("a"); got != 10 {
		t.Fatalf("served(a)=%d want 10 (default limit)", got)
	}

	// One more than the default must be turned away.
	s.OnRequest(req("a"))
	if got := eff.errored("a"); got != 1 {
		t.Fatalf("errored(a)=%d want 1 (over default limit)", got)
	}
}
// TestFIFO_ConcurrencyLimit_CustomLimit checks that a positive ConcurrencyLimit
// replaces the default: with a limit of two, three back-to-back requests yield
// two served and one rejected.
func TestFIFO_ConcurrencyLimit_CustomLimit(t *testing.T) {
	s, eff := newFIFOWithLimit(t, "a", 2)

	// Send one request more than the configured limit allows.
	for i := 0; i < 3; i++ {
		s.OnRequest(req("a"))
	}

	if got := eff.served("a"); got != 2 {
		t.Fatalf("served(a)=%d want 2 (custom limit)", got)
	}
	if got := eff.errored("a"); got != 1 {
		t.Fatalf("errored(a)=%d want 1 (over custom limit)", got)
	}
}
// TestFIFO_ConcurrencyLimit_SwapWaiters checks that the limit also applies to
// callers waiting on a swap: with a limit of two and three waiters, completing
// the swap serves two of them and rejects the third.
func TestFIFO_ConcurrencyLimit_SwapWaiters(t *testing.T) {
	eff := newFakeEffects()
	eff.states["a"] = process.StateStopped

	models := make(map[string]config.ModelConfig, 1)
	models["a"] = config.ModelConfig{ConcurrencyLimit: 2}
	s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, config.FifoConfig{}, models, eff)

	// Three requests arrive while the model is not ready: the first starts the
	// swap and the other two join it.
	for i := 0; i < 3; i++ {
		s.OnRequest(req("a"))
	}
	if got := eff.startsFor("a"); got != 1 {
		t.Fatalf("StartSwap(a)=%d want 1", got)
	}

	// Completing the swap hands out grants: two served (the limit), one
	// rejected.
	eff.states["a"] = process.StateReady
	s.OnSwapDone(SwapDone{ModelID: "a"})

	if got := eff.served("a"); got != 2 {
		t.Fatalf("served(a)=%d want 2 (limit on swap completion)", got)
	}
	if got := eff.errored("a"); got != 1 {
		t.Fatalf("errored(a)=%d want 1 (excess waiter rejected)", got)
	}
}
+17 -3
View File
@@ -11,9 +11,11 @@ package scheduler
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon" "github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process" "github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/shared" "github.com/mostlygeek/llama-swap/internal/shared"
@@ -90,9 +92,21 @@ type Effects interface {
StopProcesses(timeout time.Duration, ids []string) StopProcesses(timeout time.Duration, ids []string)
} }
// Factory builds a Scheduler bound to a baseRouter's Effects. The concrete // New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured
// router captures its Swapper in the closure it passes as a Factory. // from conf and bound to the given planner and effects. Currently only "fifo"
type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler // (the default) is supported.
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
	// An empty selector means "fifo", so both spellings share one case.
	switch kind := conf.Routing.Scheduler.Use; kind {
	case "", "fifo":
		fifoCfg := conf.Routing.Scheduler.Settings.Fifo
		return NewFIFO(name, logger, planner, fifoCfg, conf.Models, eff), nil
	default:
		// Any other value is a configuration error; report it quoted so an
		// odd or whitespace-only value is visible in the message.
		return nil, fmt.Errorf("unsupported scheduler type: %q", kind)
	}
}
// HandlerReq is one in-flight ServeHTTP request waiting for a routing decision. // HandlerReq is one in-flight ServeHTTP request waiting for a routing decision.
type HandlerReq struct { type HandlerReq struct {
-57
View File
@@ -1,57 +0,0 @@
package server
import (
"net/http"
"golang.org/x/sync/semaphore"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
// the model config leaves concurrencyLimit unset. Matches the legacy
// proxy.Process default.
const defaultConcurrencyLimit = 10
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
// model-dispatched requests per model. Each model gets a semaphore sized to
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
// immediately acquire a slot is rejected with 429. Models without a local
// config entry (e.g. peer-routed models) are not limited.
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
for id, mc := range cfg.Models {
limit := defaultConcurrencyLimit
if mc.ConcurrencyLimit > 0 {
limit = mc.ConcurrencyLimit
}
semaphores[id] = semaphore.NewWeighted(int64(limit))
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := shared.FetchContext(r, cfg)
if err != nil {
shared.SendError(w, r, shared.ErrNoModelInContext)
return
}
// fall through for peer models
sem, ok := semaphores[data.ModelID]
if !ok {
next.ServeHTTP(w, r)
return
}
if !sem.TryAcquire(1) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":"Too many requests"}`))
return
}
defer sem.Release(1)
next.ServeHTTP(w, r)
})
}
}
-75
View File
@@ -1,75 +0,0 @@
package server
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/shared"
)
func concurrencyTestReq(model string) *http.Request {
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
return r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: model, ModelID: model}))
}
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
cfg := config.Config{
Models: map[string]config.ModelConfig{
"m1": {ConcurrencyLimit: 1},
},
}
entered := make(chan struct{})
release := make(chan struct{})
var once sync.Once
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
once.Do(func() { close(entered) })
<-release
w.WriteHeader(http.StatusOK)
})
h := CreateConcurrencyMiddleware(cfg)(final)
// First request occupies the only slot.
done := make(chan struct{})
go func() {
defer close(done)
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
}()
<-entered
// Second concurrent request is rejected with 429.
w := httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("m1"))
if w.Code != http.StatusTooManyRequests {
t.Fatalf("over-limit status = %d, want 429", w.Code)
}
// Once the slot frees, a new request succeeds.
close(release)
<-done
w = httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("m1"))
if w.Code != http.StatusOK {
t.Fatalf("post-release status = %d, want 200", w.Code)
}
}
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{}}
called := 0
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called++
w.WriteHeader(http.StatusOK)
})
h := CreateConcurrencyMiddleware(cfg)(final)
w := httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
if w.Code != http.StatusOK || called != 1 {
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
}
}
-1
View File
@@ -177,7 +177,6 @@ func (s *Server) routes() {
modelChain := chain.New( modelChain := chain.New(
authMW, authMW,
CreateRequestContextMiddleware(s.cfg), CreateRequestContextMiddleware(s.cfg),
CreateConcurrencyMiddleware(s.cfg),
CreateFilterMiddleware(s.cfg), CreateFilterMiddleware(s.cfg),
CreateFormFilterMiddleware(s.cfg), CreateFormFilterMiddleware(s.cfg),
CreateInflightMiddleware(s.inflight), CreateInflightMiddleware(s.inflight),
+10
View File
@@ -37,6 +37,16 @@ var (
) )
func SendError(w http.ResponseWriter, r *http.Request, err error) { func SendError(w http.ResponseWriter, r *http.Request, err error) {
var httpErr HTTPError
if errors.As(err, &httpErr) {
for k, v := range httpErr.Header() {
w.Header()[k] = v
}
w.WriteHeader(httpErr.StatusCode())
w.Write(httpErr.Body())
return
}
switch { switch {
case errors.Is(err, ErrNoModelInContext): case errors.Is(err, ErrNoModelInContext):
SendResponse(w, r, http.StatusNotFound, "no model id could be identified") SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
+63
View File
@@ -0,0 +1,63 @@
package shared
import (
"encoding/json"
"net/http"
"strconv"
)
// HTTPError is an error that carries a complete HTTP response. A producer (e.g.
// a scheduler shedding a request) returns one of these; a renderer (e.g.
// SendError) writes the status, headers, and body verbatim instead of
// mapping the error to a generic status. It is the seam that lets a component
// shed a request with a rich response (e.g. a 429 with rate-limit headers and a
// JSON hint body) without the renderer knowing the producer's internals.
type HTTPError interface {
error
// StatusCode is the HTTP status code the renderer writes.
StatusCode() int
// Header holds the response headers the renderer copies onto the response.
Header() http.Header
// Body is the raw response body the renderer writes after the status.
Body() []byte
}
// ConcurrencyLimitError is the HTTPError used to turn a request away with a
// 429 when a concurrency limit has been reached. The zero value is ready to
// use: it answers with "Retry-After: 1" and the body
// {"error":"Too many requests"}.
type ConcurrencyLimitError struct {
	// RetryAfter is the Retry-After header value in seconds. Values <= 0
	// select the default of 1.
	RetryAfter int

	// Message is the "error" field of the JSON body. An empty string selects
	// the default "Too many requests".
	Message string
}

// Error implements error. The text is fixed and does not include Message.
func (e ConcurrencyLimitError) Error() string {
	return "concurrency limit reached"
}

// StatusCode always reports 429 Too Many Requests.
func (e ConcurrencyLimitError) StatusCode() int {
	return http.StatusTooManyRequests
}

// Header returns a newly built header set on every call, holding a JSON
// Content-Type and the Retry-After value.
func (e ConcurrencyLimitError) Header() http.Header {
	// Both keys are already in canonical form, so the literal matches what
	// Header.Set would have stored.
	return http.Header{
		"Content-Type": []string{"application/json"},
		"Retry-After":  []string{e.retryAfter()},
	}
}

// Body returns the JSON object {"error": <message>}.
func (e ConcurrencyLimitError) Body() []byte {
	payload := struct {
		Error string `json:"error"`
	}{Error: e.message()}

	// A struct with a single string field always marshals, so the error is
	// deliberately discarded.
	encoded, _ := json.Marshal(payload)
	return encoded
}

// retryAfter renders the Retry-After seconds, using 1 when RetryAfter is not
// positive.
func (e ConcurrencyLimitError) retryAfter() string {
	seconds := 1
	if e.RetryAfter > 0 {
		seconds = e.RetryAfter
	}
	return strconv.Itoa(seconds)
}

// message returns Message, or the default text when Message is empty.
func (e ConcurrencyLimitError) message() string {
	if e.Message == "" {
		return "Too many requests"
	}
	return e.Message
}