schedule,shared: move concurrency 429 limits into scheduler code (#849)

- make concurrency limiting the scheduler.Scheduler's responsibility
- eliminate the separate concurrency limit middleware 
- move concurrencyLimit logic into scheduler.FIFO to maintain backwards compatibility
- add HTTPError from #834 

Updates #834
This commit is contained in:
Benson Wong
2026-06-15 22:35:12 -07:00
committed by GitHub
parent 8e84b2ec4f
commit 6cf1317341
14 changed files with 278 additions and 171 deletions
-57
View File
@@ -1,57 +0,0 @@
package server
import (
"net/http"
"golang.org/x/sync/semaphore"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// defaultConcurrencyLimit caps simultaneous in-flight requests per model when
// the model config leaves concurrencyLimit unset. Matches the legacy
// proxy.Process default.
const defaultConcurrencyLimit = 10
// CreateConcurrencyMiddleware returns middleware that limits simultaneous
// model-dispatched requests per model. Each model gets a semaphore sized to
// its concurrencyLimit (or defaultConcurrencyLimit). A request that cannot
// immediately acquire a slot is rejected with 429. Models without a local
// config entry (e.g. peer-routed models) are not limited.
func CreateConcurrencyMiddleware(cfg config.Config) chain.Middleware {
semaphores := make(map[string]*semaphore.Weighted, len(cfg.Models))
for id, mc := range cfg.Models {
limit := defaultConcurrencyLimit
if mc.ConcurrencyLimit > 0 {
limit = mc.ConcurrencyLimit
}
semaphores[id] = semaphore.NewWeighted(int64(limit))
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := shared.FetchContext(r, cfg)
if err != nil {
shared.SendError(w, r, shared.ErrNoModelInContext)
return
}
// fall through for peer models
sem, ok := semaphores[data.ModelID]
if !ok {
next.ServeHTTP(w, r)
return
}
if !sem.TryAcquire(1) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":"Too many requests"}`))
return
}
defer sem.Release(1)
next.ServeHTTP(w, r)
})
}
}
-75
View File
@@ -1,75 +0,0 @@
package server
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/shared"
)
func concurrencyTestReq(model string) *http.Request {
r := httptest.NewRequest("GET", "/v1/chat/completions", nil)
return r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: model, ModelID: model}))
}
func TestServer_ConcurrencyMiddleware_RejectsOverLimit(t *testing.T) {
cfg := config.Config{
Models: map[string]config.ModelConfig{
"m1": {ConcurrencyLimit: 1},
},
}
entered := make(chan struct{})
release := make(chan struct{})
var once sync.Once
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
once.Do(func() { close(entered) })
<-release
w.WriteHeader(http.StatusOK)
})
h := CreateConcurrencyMiddleware(cfg)(final)
// First request occupies the only slot.
done := make(chan struct{})
go func() {
defer close(done)
h.ServeHTTP(httptest.NewRecorder(), concurrencyTestReq("m1"))
}()
<-entered
// Second concurrent request is rejected with 429.
w := httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("m1"))
if w.Code != http.StatusTooManyRequests {
t.Fatalf("over-limit status = %d, want 429", w.Code)
}
// Once the slot frees, a new request succeeds.
close(release)
<-done
w = httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("m1"))
if w.Code != http.StatusOK {
t.Fatalf("post-release status = %d, want 200", w.Code)
}
}
func TestServer_ConcurrencyMiddleware_UnconfiguredModelPassesThrough(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{}}
called := 0
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called++
w.WriteHeader(http.StatusOK)
})
h := CreateConcurrencyMiddleware(cfg)(final)
w := httptest.NewRecorder()
h.ServeHTTP(w, concurrencyTestReq("peer-model"))
if w.Code != http.StatusOK || called != 1 {
t.Fatalf("unconfigured model: status=%d called=%d, want 200/1", w.Code, called)
}
}
-1
View File
@@ -177,7 +177,6 @@ func (s *Server) routes() {
modelChain := chain.New(
authMW,
CreateRequestContextMiddleware(s.cfg),
CreateConcurrencyMiddleware(s.cfg),
CreateFilterMiddleware(s.cfg),
CreateFormFilterMiddleware(s.cfg),
CreateInflightMiddleware(s.inflight),