Files
llama-swap/internal/router/base.go
T
Benson Wong 9be9a87fa0 internal/process: improve windows shutdown behaviour (#808)
Add Windows specific shutdown code paths so stopping of child processes
is more reliable:

- stopping llama-swap won't leave behind any child processes it created
- uses Job Objects in Windows so the whole llama-swap tree is closed by
the os
- add procCtx to baseRouter. It replaces shutdownCtx as a signal for
managing lifetime state.
- shutdownCtx is only used by the router to stop handling new requests
during shutdown
- improve debug logging to make it easier to trace source of issues

Fixes #804
Updates #807
2026-06-01 00:45:30 -07:00

801 lines
24 KiB
Go

package router
import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type shutdownReq struct {
timeout time.Duration
respond chan error
}
type unloadReq struct {
targets []string
timeout time.Duration
respond chan struct{}
}
type handlerReq struct {
model string
ctx context.Context
respond chan handlerResp
positionCh chan int
}
type handlerResp struct {
handleFunc http.HandlerFunc
err error
}
type swapDone struct {
modelID string
err error
}
type serveDoneEvent struct {
modelID string
}
type activeSwap struct {
modelID string
evict []string
waiters []handlerReq
}
// swapPlanner is the only piece of behaviour that differs between concrete
// routers. baseRouter never inspects its internals.
type swapPlanner interface {
// EvictionFor returns running model IDs that must be stopped before
// target can serve. alsoRunning lists models the baseRouter has already
// committed to loading (in-flight swaps) which the planner cannot see
// via process.State() yet. Pure decision; must not log.
EvictionFor(target string, alsoRunning []string) []string
// OnSwapStart runs once at the start of every swap. Planners may log
// their decision here at whatever verbosity they choose.
OnSwapStart(target string)
}
// baseRouter owns the channels, run-loop, and orchestration code shared by
// every concrete router. Concrete routers embed *baseRouter and supply a
// swapPlanner that captures how their eviction set is decided.
type baseRouter struct {
name string
config config.Config
processes map[string]process.Process
logger *logmon.Monitor
planner swapPlanner
// shutdownCtx governs the request machinery: cancelling it tells grant()
// and ServeHTTP to stop granting and reject callers. It is deliberately
// separate from procCtx — see procCtx below.
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
// procCtx is the parent context for every managed process and governs
// process lifetime only. handleShutdown stops processes gracefully via
// Stop() and cancels procCtx afterwards, so teardown is never a context
// cancel racing the graceful path (which collapsed the grace to 100ms and
// let the caller return before children were reaped — see process run loop).
procCtx context.Context
procCancel context.CancelFunc
handlerCh chan handlerReq
shutdownCh chan shutdownReq
unloadCh chan unloadReq
swapDoneCh chan swapDone
serveDoneCh chan serveDoneEvent
runDone chan struct{}
// testProcessed, when non-nil, receives one event after each handlerReq
// or swapDone has been fully processed by run(). Tests use it to wait
// for run() to reach a deterministic state without sleeping. serveDone
// events are intentionally NOT signalled here so test event counts
// remain stable.
testProcessed chan struct{}
}
func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter {
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
procCtx, procCancel := context.WithCancel(context.Background())
return &baseRouter{
name: name,
config: conf,
processes: processes,
logger: logger,
planner: planner,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
procCtx: procCtx,
procCancel: procCancel,
handlerCh: make(chan handlerReq),
shutdownCh: make(chan shutdownReq),
unloadCh: make(chan unloadReq),
swapDoneCh: make(chan swapDone),
serveDoneCh: make(chan serveDoneEvent),
runDone: make(chan struct{}),
}
}
func (b *baseRouter) notifyProcessed() {
if b.testProcessed != nil {
b.testProcessed <- struct{}{}
}
}
func (b *baseRouter) run() {
defer close(b.runDone)
active := make(map[string]*activeSwap)
inFlight := make(map[string]int)
var queued []handlerReq
for {
select {
case req := <-b.shutdownCh:
b.handleShutdown(req, active, queued)
return
case req := <-b.handlerCh:
b.handleRequest(req, active, inFlight, &queued)
b.notifyProcessed()
case req := <-b.unloadCh:
b.handleUnload(req, active, inFlight, &queued)
b.notifyProcessed()
case ev := <-b.swapDoneCh:
b.handleSwapDone(ev, active, inFlight, &queued)
b.notifyProcessed()
case ev := <-b.serveDoneCh:
b.handleServeDone(ev, active, inFlight, &queued)
}
}
}
// grant sends a response back to the caller of ServeHTTP and tells us
// whether the caller was still there to receive it.
//
// Each ServeHTTP creates a fresh, UNBUFFERED respond channel and parks in
// a select waiting on it. "Unbuffered" is the important word: a send only
// completes when the other side is actively receiving. So if this send
// succeeds, we know for a fact the caller picked up the response and will
// act on it. If the caller has already given up (its request context was
// cancelled, e.g. the HTTP client disconnected) or the router is shutting
// down, the send never lands, one of the other select cases fires, and we
// report back that the grant did NOT happen.
//
// That distinction matters for in-flight bookkeeping — see grantHandler.
func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool {
select {
case req.respond <- resp:
return true
case <-req.ctx.Done():
return false
case <-b.shutdownCtx.Done():
return false
}
}
// grantHandler is the "this caller can now use process p" path. It does
// two things that must stay locked together:
//
// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the
// HTTP request finishes, the run loop hears about it.
// 2. Bump inFlight[modelID] so the router knows this process is busy and
// refuses to evict it until the count comes back down.
//
// The increment is gated on grant() returning true. If grant() returns
// false, the caller already walked away and trackedServe will never run —
// which means no matching decrement will ever arrive on serveDoneCh.
// Incrementing in that case would strand the counter at >0 forever and
// the router would never again be willing to swap this model out.
//
// In short: increment if and only if we know a decrement is coming.
func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) {
if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) {
inFlight[modelID]++
}
}
// trackedServe is the wrapper that closes the loop on in-flight tracking.
// It runs p.ServeHTTP normally; the only added behaviour is a deferred
// send on serveDoneCh after the handler returns. That send is what tells
// the run loop "this model now has one fewer request in flight — go look
// at the queue again, you may be able to start a swap you previously had
// to defer."
//
// The select on shutdownCtx.Done() is a release valve: if the router is
// already shutting down, nobody is reading serveDoneCh, so we drop the
// notification rather than blocking the HTTP goroutine forever.
func (b *baseRouter) trackedServe(modelID string, p process.Process) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
select {
case b.serveDoneCh <- serveDoneEvent{modelID: modelID}:
case <-b.shutdownCtx.Done():
}
}()
p.ServeHTTP(w, r)
}
}
// handleRequest decides what to do with one incoming ServeHTTP request. It is
// called from run() and never blocks indefinitely: any work that has to wait
// (starting a process, stopping siblings, waiting for ready) is deferred to
// a swap goroutine and reported back via swapDoneCh.
//
// The decision tree, in order:
//
// 1. Unknown model — respond with ErrNoLocalModelFound and move on.
// 2. A swap to the same model is already in flight — attach this waiter so
// one swap serves all callers that asked for the same model.
// 3. Fast path — the target process is already ready, the planner sees
// nothing to evict, and no in-flight swap is evicting it. Hand back its
// ServeHTTP immediately (wrapped so the run loop knows when it ends).
// 4. Would collide with an in-flight swap (we'd stop their target, or
// they're stopping us) — park in the queue for handleSwapDone to drain.
// 5. Would evict a process that is still handling requests — park in the
// queue. handleServeDone will retry when the busy process drains.
// 6. Otherwise — start a new swap. This may run in parallel with other
// active swaps when their evict sets don't intersect.
func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
// (1) Unknown model.
p, ok := b.processes[req.model]
if !ok {
b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model)
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
return
}
// (2) Join an in-flight swap for the same model.
if s, ok := active[req.model]; ok {
b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1)
s.waiters = append(s.waiters, req)
return
}
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model)
b.grantHandler(req, req.model, p, inFlight)
return
}
// (4) Collision with an in-flight swap — queue.
if collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model)
*queued = append(*queued, req)
b.broadcastQueuePositions(*queued)
return
}
// (5) Would evict a busy process — queue until it drains.
if conflictsWithInFlight(evict, inFlight) {
b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model)
*queued = append(*queued, req)
b.broadcastQueuePositions(*queued)
return
}
// (6) Start a new (possibly parallel) swap.
b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict)
s := b.startSwap(req, evict)
active[s.modelID] = s
}
// handleSwapDone is called from run() when a swap goroutine reports that it
// has finished. It fans out the result to every waiter that joined this swap,
// removes the swap from the active map, and then walks the queue once,
// promoting any items that no longer collide with the remaining active set.
// FIFO order is preserved: items still blocked stay in place.
func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
s, ok := active[ev.modelID]
if !ok {
return
}
delete(active, ev.modelID)
for _, w := range s.waiters {
if ev.err != nil {
b.grant(w, handlerResp{err: ev.err})
} else {
p := b.processes[ev.modelID]
b.grantHandler(w, ev.modelID, p, inFlight)
}
}
b.drainQueue(active, inFlight, queued)
}
// handleServeDone is called from run() each time a tracked ServeHTTP
// finishes. It decrements the per-model in-flight count and, when that
// drops to zero, retries the queue: requests whose swap was deferred
// because they would have evicted this (now-idle) process can now proceed.
func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
inFlight[ev.modelID]--
if inFlight[ev.modelID] <= 0 {
delete(inFlight, ev.modelID)
b.drainQueue(active, inFlight, queued)
}
}
// drainQueue walks the queued requests in order, re-running the handleRequest
// decision tree against the (now smaller) active set. Items that can now start
// or join become satisfied; items still blocked remain queued in original
// order so they get another chance on the next swap completion.
func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
if len(*queued) == 0 {
return
}
pending := *queued
var remaining []handlerReq
for _, req := range pending {
p, ok := b.processes[req.model]
if !ok {
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
continue
}
if s, ok := active[req.model]; ok {
b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model)
s.waiters = append(s.waiters, req)
continue
}
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model)
b.grantHandler(req, req.model, p, inFlight)
continue
}
if collidesWith(req.model, evict, active) {
remaining = append(remaining, req)
continue
}
if conflictsWithInFlight(evict, inFlight) {
remaining = append(remaining, req)
continue
}
b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict)
s := b.startSwap(req, evict)
active[s.modelID] = s
}
*queued = remaining
b.broadcastQueuePositions(*queued)
}
// broadcastQueuePositions sends each queued request its current 1-indexed
// position. Sends are non-blocking: if the channel is full, the old value is
// drained first so the consumer always sees the latest position.
func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) {
for i, req := range queued {
pos := i + 1
select {
case req.positionCh <- pos:
default:
select {
case <-req.positionCh:
default:
}
select {
case req.positionCh <- pos:
default:
}
}
}
}
func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap {
swap := &activeSwap{
modelID: initial.model,
evict: evict,
waiters: []handlerReq{initial},
}
b.planner.OnSwapStart(initial.model)
go b.doSwap(initial.model, evict)
return swap
}
// activeTargets returns the IDs of every in-flight swap target except exclude.
// baseRouter passes this to the planner so eviction decisions account for
// models that have been committed to but have not yet transitioned to
// StateStarting in their process state machine.
func activeTargets(active map[string]*activeSwap, exclude string) []string {
if len(active) == 0 {
return nil
}
out := make([]string, 0, len(active))
for id := range active {
if id == exclude {
continue
}
out = append(out, id)
}
return out
}
// collidesWith reports whether a new swap with this target and evict set can
// safely run alongside the currently active swaps. Same-target callers should
// JOIN (handled before this) — they do not collide with themselves.
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
for id, s := range active {
if id == target {
continue
}
if containsString(evict, id) {
return true
}
if containsString(s.evict, target) {
return true
}
}
return false
}
// conflictsWithInFlight reports whether any model in evict is still handling
// requests. Stopping a busy process would cancel its callers' connections,
// so the router defers the swap until those callers finish.
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
for _, m := range evict {
if inFlight[m] > 0 {
return true
}
}
return false
}
func containsString(xs []string, s string) bool {
for _, x := range xs {
if x == s {
return true
}
}
return false
}
func (b *baseRouter) doSwap(modelID string, toStop []string) {
timeout := b.healthCheckTimeout()
var wg sync.WaitGroup
for _, mID := range toStop {
wg.Add(1)
go func(p process.Process, id string) {
defer wg.Done()
if err := p.Stop(timeout); err != nil {
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
}
}(b.processes[mID], mID)
}
wg.Wait()
target := b.processes[modelID]
if target.State() == process.StateStopped {
go func() {
if err := target.Run(timeout); err != nil {
b.logger.Warnf("%s: running %s exited: %v", b.name, modelID, err)
}
}()
}
err := target.WaitReady(b.shutdownCtx)
select {
case b.swapDoneCh <- swapDone{modelID: modelID, err: err}:
case <-b.shutdownCtx.Done():
}
}
func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) {
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
// Cancel shutdownCtx first so any waiter that is currently parked on
// its respond channel can exit via its own shutdownCtx.Done() branch.
// The grant calls below then either land (waiter happened to receive
// before noticing shutdown) or fall through immediately via grant's
// shutdownCtx case — either way the waiter sees a non-OK response.
// This does NOT touch processes: their lifetime is procCtx, cancelled
// only after the graceful Stop() calls below have reaped them.
b.shutdownFn()
for _, s := range active {
for _, w := range s.waiters {
b.grant(w, handlerResp{err: shutdownErr})
}
}
for _, w := range queued {
b.grant(w, handlerResp{err: shutdownErr})
}
stopTimeout := req.timeout
if stopTimeout <= 0 {
stopTimeout = b.healthCheckTimeout()
}
var wg sync.WaitGroup
for i, p := range b.processes {
wg.Add(1)
go func(id string, p process.Process) {
defer wg.Done()
if err := p.Stop(stopTimeout); err != nil {
b.logger.Warnf("%s failed to stop process %s: %v", b.name, id, err)
}
}(i, p)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
if req.timeout > 0 {
select {
case <-done:
case <-time.After(req.timeout):
<-done
}
} else {
<-done
}
// Every process is stopped (children reaped via Stop()). Cancel procCtx so
// the process run-loop goroutines exit; they are already StateStopped, so
// this is a clean no-op kill rather than a forced teardown.
b.procCancel()
req.respond <- nil
}
func (b *baseRouter) healthCheckTimeout() time.Duration {
t := time.Duration(b.config.HealthCheckTimeout) * time.Second
if t <= 0 {
return 30 * time.Second
}
return t
}
func (b *baseRouter) Handles(model string) bool {
_, ok := b.processes[model]
return ok
}
func (b *baseRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
if p, ok := b.processes[modelID]; ok {
return p.Logger(), true
}
return nil, false
}
// RunningModels returns the current state of every process that is not stopped
// or shut down. The processes map keys are fixed at construction and State()
// is a snapshot, so this is safe to call without the run loop.
func (b *baseRouter) RunningModels() map[string]process.ProcessState {
running := make(map[string]process.ProcessState)
for id, p := range b.processes {
st := p.State()
if st == process.StateStopped || st == process.StateShutdown {
continue
}
running[id] = st
}
return running
}
// Unload stops the named models, or every running model when none are named.
// It blocks until each targeted process has stopped.
//
// The request is funneled through the run loop so eviction is coordinated
// with the rest of the router's state: pending swap waiters for an
// unloaded model are released with an error, queued requests for unloaded
// models are dropped, and any deferred swaps that were waiting on those
// models become eligible to start.
//
// In-flight requests being served by an unloaded process are not waited
// for — Stop kills the upstream, those callers see whatever error the
// reverse proxy surfaces and may retry. Their trackedServe defers fire
// normally and decrement inFlight as the dying handlers return.
func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
targets := models
if len(targets) == 0 {
targets = make([]string, 0, len(b.processes))
for id := range b.processes {
targets = append(targets, id)
}
}
if len(targets) == 0 {
return
}
req := unloadReq{targets: targets, timeout: timeout, respond: make(chan struct{})}
select {
case b.unloadCh <- req:
case <-b.runDone:
return
}
<-req.respond
}
// handleUnload runs on the run loop in response to an Unload call. It
// reconciles router-owned state with the impending Stop, then performs
// the Stop synchronously so callers of Unload remain blocked until each
// targeted process has actually exited.
func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
unloadErr := fmt.Errorf("%s: model unloaded", b.name)
targetSet := make(map[string]bool, len(req.targets))
for _, id := range req.targets {
targetSet[id] = true
}
// Release waiters of any in-flight swap whose target is being
// unloaded. The swap goroutine itself is left to finish on its own;
// when its swapDone arrives, handleSwapDone will find no entry in
// active and silently drop it.
for id := range targetSet {
s, ok := active[id]
if !ok {
continue
}
for _, w := range s.waiters {
b.grant(w, handlerResp{err: unloadErr})
}
delete(active, id)
}
// Drop queued requests addressed to unloaded models. Requests for
// other models stay queued and may benefit from drainQueue at the end.
if len(*queued) > 0 {
kept := (*queued)[:0]
for _, w := range *queued {
if targetSet[w.model] {
b.grant(w, handlerResp{err: unloadErr})
continue
}
kept = append(kept, w)
}
*queued = kept
}
// Stop the targeted processes. Done synchronously so Unload's caller
// can rely on "after Unload returns, the process is stopped". inFlight
// is intentionally NOT cleared here: each dying handler will fire its
// trackedServe defer and reach handleServeDone in the normal way once
// the run loop is free again.
var wg sync.WaitGroup
for id := range targetSet {
p, ok := b.processes[id]
if !ok {
continue
}
wg.Add(1)
go func(id string, p process.Process) {
defer wg.Done()
if err := p.Stop(req.timeout); err != nil {
b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err)
}
}(id, p)
}
wg.Wait()
// Removing entries from active above may have unblocked queued
// requests that previously collided with the now-cancelled swaps.
b.drainQueue(active, inFlight, queued)
close(req.respond)
}
func (b *baseRouter) Shutdown(timeout time.Duration) error {
if !b.shuttingDown.CompareAndSwap(false, true) {
return fmt.Errorf("%s shutdown already in progress", b.name)
}
req := shutdownReq{timeout: timeout, respond: make(chan error, 1)}
select {
case b.shutdownCh <- req:
case <-b.runDone:
return nil
}
return <-req.respond
}
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if b.shuttingDown.Load() {
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
data, err := FetchContext(req, b.config)
if err != nil {
SendError(w, req, err)
return
}
hr := handlerReq{
model: data.ModelID,
ctx: req.Context(),
// Unbuffered: a successful send on respond proves the waiter is
// alive and consuming. grant() relies on this to avoid handing a
// handleFunc to a cancelled waiter and leaking the inFlight count.
respond: make(chan handlerResp),
positionCh: make(chan int, 1),
}
select {
case b.handlerCh <- hr:
case <-req.Context().Done():
return
case <-b.shutdownCtx.Done():
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
isModelReady := false
if p, ok := b.processes[data.ModelID]; ok {
isModelReady = p.State() == process.StateReady
}
shouldShowLoading := data.Streaming && data.SendLoadingState && isLoadingPath(req.URL.Path) && !isModelReady
var lw *loadingWriter
cancelLoad := func() {}
if shouldShowLoading {
var swapCtx context.Context
swapCtx, cancelLoad = context.WithCancel(req.Context())
lw = newLoadingWriter(b.logger, data.ModelID, w, req)
go lw.start(swapCtx)
go func() {
for {
select {
case pos := <-hr.positionCh:
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
case <-swapCtx.Done():
return
}
}
}()
}
// finishLoading stops the loading stream and fences its goroutine off from
// the ResponseWriter before the real handler (or ServeHTTP's return)
// reclaims it. release() must run even when waitForCompletion times out:
// otherwise a still-streaming goroutine flushes a finalized response and
// panics on the recycled *bufio.Writer.
finishLoading := func() {
cancelLoad()
if lw != nil {
lw.waitForCompletion(1 * time.Second)
lw.release()
}
}
var resp handlerResp
select {
case resp = <-hr.respond:
finishLoading()
case <-req.Context().Done():
finishLoading()
return
case <-b.shutdownCtx.Done():
finishLoading()
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
if resp.err != nil {
SendError(w, req, resp.err)
return
}
resp.handleFunc(w, req)
}