diff --git a/config-schema.json b/config-schema.json index 4142c397..d0dc2da0 100644 --- a/config-schema.json +++ b/config-schema.json @@ -601,10 +601,11 @@ "use": { "type": "string", "enum": [ + "serial", "fifo" ], - "default": "fifo", - "description": "Scheduler to use. Only 'fifo' is currently supported." + "default": "serial", + "description": "Scheduler to use. 'serial' (default on this fork): strict one-model-at-a-time, requests run in exact arrival order, switching models evicts every other model first. 'fifo': throughput-oriented, batches same-model requests and allows parallel/co-resident models." }, "settings": { "type": "object", diff --git a/config.example.yaml b/config.example.yaml index 27de1131..f495cadd 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -556,11 +556,21 @@ routing: # expands to: [L] full: "L" - # scheduler: how queued requests are ordered. - # The default and only valid scheduler is "fifo" + # scheduler: how queued requests are ordered and run. + # - optional, default on this fork: "serial" + # - valid values: + # - "serial": strict one-model-at-a-time. Requests run in exact arrival + # order; only one request runs at a time; switching to a different model + # evicts every other running model first so a single model occupies memory + # at a time. This ignores group/matrix co-residency entirely. The "fifo" + # settings below (priority) do not apply. + # - "fifo": throughput-oriented. Same-model requests are batched to reduce + # swaps and a model serves up to its concurrencyLimit in parallel; models + # in non-exclusive groups can run concurrently. Requests may be reordered. scheduler: - use: fifo + use: serial settings: + # fifo settings only apply when use: fifo fifo: # priority: a dictionary of model ID -> priority # - optional, default: empty dictionary diff --git a/internal/config/config_posix_test.go b/internal/config/config_posix_test.go index d3fd9546..4f7591cd 100644 --- a/internal/config/config_posix_test.go +++ b/internal/config/config_posix_test.go @@ -277,7 +277,7 @@ groups: }, }, Scheduler: SchedulerConfig{ - Use: "fifo", + Use: "serial", }, }, } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 2125d242..eef80b96 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1572,7 +1572,7 @@ groups: assert.Equal(t, "group", cfg.Routing.Router.Use) // default group injected for orphaned models (none here) still leaves g1 assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1") - assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use) + assert.Equal(t, "serial", cfg.Routing.Scheduler.Use) } func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) { @@ -1631,7 +1631,7 @@ func TestConfig_Routing_DefaultsToGroup(t *testing.T) { cfg, err := LoadConfigFromReader(strings.NewReader(twoModels)) require.NoError(t, err) assert.Equal(t, "group", cfg.Routing.Router.Use) - assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use) + assert.Equal(t, "serial", cfg.Routing.Scheduler.Use) } func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) { diff --git a/internal/config/config_windows_test.go b/internal/config/config_windows_test.go index 7f53a25d..f0d7cfe3 100644 --- a/internal/config/config_windows_test.go +++ b/internal/config/config_windows_test.go @@ -266,7 +266,7 @@ groups: }, }, Scheduler: SchedulerConfig{ - Use: "fifo", + Use: "serial", }, }, } diff --git a/internal/config/load.go b/internal/config/load.go index 6e9585d4..a56358e3 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -358,11 +358,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { config.Routing.Router.Settings.Matrix = config.Matrix config.Routing.Router.Settings.Groups = config.Groups + // This fork defaults to the "serial" scheduler: one model loaded at a time, + // requests served in strict arrival order. Set use: fifo for the upstream + // throughput-oriented behavior that batches same-model requests. if config.Routing.Scheduler.Use == "" { - config.Routing.Scheduler.Use = "fifo" + config.Routing.Scheduler.Use = "serial" } - if config.Routing.Scheduler.Use != "fifo" { - return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use) + switch config.Routing.Scheduler.Use { + case "fifo", "serial": + default: + return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo, serial)", config.Routing.Scheduler.Use) } for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority { if _, found := config.RealModelName(modelID); !found { diff --git a/internal/router/scheduler/scheduler.go b/internal/router/scheduler/scheduler.go index 8b87db3b..86439ede 100644 --- a/internal/router/scheduler/scheduler.go +++ b/internal/router/scheduler/scheduler.go @@ -92,9 +92,14 @@ type Effects interface { StopProcesses(timeout time.Duration, ids []string) } -// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured -// from conf and bound to the given planner and effects. Currently only "fifo" -// (the default) is supported. +// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured from +// conf and bound to the given planner and effects. Supported values are "fifo" +// (throughput-oriented, batches same-model requests) and "serial" (strict +// one-model-at-a-time, exact arrival order). +// +// The deployment default is applied by config loading (LoadConfig sets Use to +// "serial" when unset). The "" fallback here is the library default and remains +// "fifo" so callers that build a Config directly keep the original behavior. func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) { use := conf.Routing.Scheduler.Use if use == "" { @@ -103,6 +108,9 @@ func New(conf config.Config, name string, logger *logmon.Monitor, planner Swappe switch use { case "fifo": return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil + case "serial": + // Serial ignores the group planner: it always evicts every other model. + return NewSerial(name, logger, eff), nil default: return nil, fmt.Errorf("unsupported scheduler type: %q", use) } diff --git a/internal/router/scheduler/serial.go b/internal/router/scheduler/serial.go new file mode 100644 index 00000000..3b49ce8a --- /dev/null +++ b/internal/router/scheduler/serial.go @@ -0,0 +1,253 @@ +package scheduler + +import ( + "fmt" + "sort" + "time" + + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" +) + +// Serial is a strict one-model-at-a-time scheduler. Unlike FIFO it never reorders +// or batches: requests run in exact arrival order and at most one request runs at +// any instant. When the next request targets a model other than the one loaded, +// every other running model is evicted and the target is loaded before it runs, +// so a single model occupies memory at a time — at the cost of throughput. +// +// Example: A B C A is served as A B C A. The final A reloads its model even +// though it ran first, because B and C displaced it in between. (FIFO, by +// contrast, would batch the two A requests: A A B C.) +// +// Serial ignores group/eviction policy entirely: it always evicts every other +// running model, regardless of how groups are configured. That is what makes the +// single-model guarantee a property of the scheduler rather than of the config. +// +// Like FIFO, every method runs on the router's single run-loop goroutine, so no +// internal locking is needed. +type Serial struct { + name string + logger *logmon.Monitor + effects Effects + + // queued holds requests in strict arrival order. It is never reordered. + queued []HandlerReq + + // active is the one request currently being processed (loading or serving), + // or nil when idle. phase is meaningful only while active != nil. + active *HandlerReq + phase serialPhase +} + +// serialPhase is the lifecycle stage of the active request. +type serialPhase int + +const ( + phaseIdle serialPhase = iota + phaseSwapping // waiting for OnSwapDone for active.Model + phaseServing // waiting for OnServeDone for active.Model +) + +// NewSerial builds a Serial scheduler. It takes no Swapper: eviction is always +// "stop every other running model", so the group planner is not consulted. +func NewSerial(name string, logger *logmon.Monitor, eff Effects) *Serial { + return &Serial{ + name: name, + logger: logger, + effects: eff, + } +} + +// OnRequest validates the model and appends the request to the tail of the queue, +// then tries to start the next job. Unknown models fail immediately. +func (s *Serial) OnRequest(req HandlerReq) { + if _, ok := s.effects.ModelState(req.Model); !ok { + s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model) + s.effects.GrantError(req, ErrModelNotFound) + return + } + s.queued = append(s.queued, req) + broadcastQueuePositions(s.queued) + s.startNext() +} + +// startNext begins processing the head of the queue when nothing is active. It +// fast-paths a request whose model is already the sole loaded-and-ready process; +// otherwise it launches a swap that evicts every other running model first. The +// loop skips over requests for models that vanished (e.g. a config reload) and +// requests whose caller disconnected before they could be served. +func (s *Serial) startNext() { + if s.active != nil { + return // a job is already loading or serving + } + for len(s.queued) > 0 { + req := s.queued[0] + s.queued = s.queued[1:] + broadcastQueuePositions(s.queued) + + state, ok := s.effects.ModelState(req.Model) + if !ok { + s.effects.GrantError(req, ErrModelNotFound) + continue + } + + r := req + s.active = &r + + evict := s.otherRunning(req.Model) + if state == process.StateReady && len(evict) == 0 { + // Already loaded and the only model running — serve immediately. + s.logger.Debugf("%s: serving model %s (already loaded)", s.name, req.Model) + if s.serve() { + return + } + continue // caller gone; pick the next request + } + + s.logger.Debugf("%s: swapping to model %s, evicting %v", s.name, req.Model, evict) + s.phase = phaseSwapping + s.effects.StartSwap(req.Model, evict) + return + } +} + +// serve hands the active request its tracked handler. It returns true when the +// request is now serving (await OnServeDone); false when the caller had already +// disconnected, in which case active is cleared so the next job can start. +func (s *Serial) serve() bool { + if s.effects.GrantServe(*s.active, s.active.Model) { + s.phase = phaseServing + return true + } + s.logger.Debugf("%s: caller for model %s gone before serve", s.name, s.active.Model) + s.active = nil + s.phase = phaseIdle + return false +} + +// OnSwapDone fires when the load for the active request completes. On success the +// request is served; on failure its caller receives the error and the queue +// advances. A SwapDone that does not match the active load (e.g. its request was +// unloaded or cancelled mid-load) is ignored. +func (s *Serial) OnSwapDone(ev SwapDone) { + if s.active == nil || s.phase != phaseSwapping || s.active.Model != ev.ModelID { + return + } + if ev.Err != nil { + s.logger.Debugf("%s: swap for model %s failed: %v", s.name, ev.ModelID, ev.Err) + s.effects.GrantError(*s.active, ev.Err) + s.active = nil + s.phase = phaseIdle + s.startNext() + return + } + if !s.serve() { + s.startNext() // caller vanished while the model loaded; move on + } +} + +// OnServeDone fires when the active request's handler returns. The slot is freed +// and the next queued request begins. +func (s *Serial) OnServeDone(ev ServeDoneEvent) { + if s.active == nil || s.phase != phaseServing { + return + } + s.active = nil + s.phase = phaseIdle + s.startNext() +} + +// OnCancel removes a disconnected client's request from the queue. A request that +// is already active is left to finish: if it was loading, OnSwapDone's serve() +// will find the caller gone (GrantServe false) and advance; if it was serving, +// its handler returns normally and reaches OnServeDone. +func (s *Serial) OnCancel(req HandlerReq) { + if len(s.queued) == 0 { + return + } + kept := s.queued[:0] + removed := false + for _, q := range s.queued { + if q.Respond == req.Respond { + removed = true + continue + } + kept = append(kept, q) + } + s.queued = kept + if removed { + s.logger.Debugf("%s: cancelled request for model %s pruned from queue", s.name, req.Model) + broadcastQueuePositions(s.queued) + } +} + +// OnUnload reconciles state for an unload, stops the targeted processes, and +// advances the queue. It mirrors the FIFO contract: queued requests for unloaded +// models are failed; an active *loading* request for an unloaded model is failed +// (its swap goroutine is left to finish and its SwapDone is then ignored); an +// active *serving* request is left for its handler to end when StopProcesses +// kills the upstream. The Stop is synchronous so callers of Unload can rely on +// the processes being stopped on return. +func (s *Serial) OnUnload(targets []string, timeout time.Duration) { + unloadErr := fmt.Errorf("%s: model unloaded", s.name) + + targetSet := make(map[string]bool, len(targets)) + for _, id := range targets { + targetSet[id] = true + } + + if s.active != nil && s.phase == phaseSwapping && targetSet[s.active.Model] { + s.effects.GrantError(*s.active, unloadErr) + s.active = nil + s.phase = phaseIdle + } + + if len(s.queued) > 0 { + kept := s.queued[:0] + for _, q := range s.queued { + if targetSet[q.Model] { + s.effects.GrantError(q, unloadErr) + continue + } + kept = append(kept, q) + } + s.queued = kept + broadcastQueuePositions(s.queued) + } + + s.effects.StopProcesses(timeout, targets) + + // A still-serving active request advances via OnServeDone when its killed + // handler returns; only start the next job when nothing is active now. + if s.active == nil { + s.startNext() + } +} + +// OnShutdown grants err to every request the scheduler still holds: an active +// loading request and all queued requests. A serving request is torn down with +// its process by the baseRouter. +func (s *Serial) OnShutdown(err error) { + if s.active != nil && s.phase == phaseSwapping { + s.effects.GrantError(*s.active, err) + s.active = nil + s.phase = phaseIdle + } + for _, q := range s.queued { + s.effects.GrantError(q, err) + } + s.queued = nil +} + +// otherRunning returns every running model except target, sorted for +// deterministic eviction. +func (s *Serial) otherRunning(target string) []string { + var out []string + for id := range s.effects.RunningModels() { + if id != target { + out = append(out, id) + } + } + sort.Strings(out) + return out +} diff --git a/internal/router/scheduler/serial_test.go b/internal/router/scheduler/serial_test.go new file mode 100644 index 00000000..60de4e43 --- /dev/null +++ b/internal/router/scheduler/serial_test.go @@ -0,0 +1,391 @@ +package scheduler + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/internal/logmon" + "github.com/mostlygeek/llama-swap/internal/process" +) + +// Serial methods all run on the router's single run-loop goroutine, so these +// tests drive them directly and synchronously, reusing fakeEffects and the +// req/reqCh helpers from fifo_test.go. A load completes via OnSwapDone and a +// served request finishes via OnServeDone — the events the run loop delivers. + +func newSerial(eff Effects) *Serial { + return NewSerial("test", logmon.NewWriter(io.Discard), eff) +} + +// lastStart returns the most recent StartSwap record. +func lastStart(t *testing.T, eff *fakeEffects) startRec { + t.Helper() + if len(eff.starts) == 0 { + t.Fatal("no StartSwap recorded") + } + return eff.starts[len(eff.starts)-1] +} + +func sameSet(a, b []string) bool { + if len(a) != len(b) { + return false + } + m := map[string]int{} + for _, x := range a { + m[x]++ + } + for _, x := range b { + m[x]-- + } + for _, v := range m { + if v != 0 { + return false + } + } + return true +} + +// servedOrder returns the model IDs of every successful serve grant in order. +func servedOrder(eff *fakeEffects) []string { + var out []string + for _, g := range eff.grants { + if g.err == nil && g.serve { + out = append(out, g.model) + } + } + return out +} + +func TestSerial_FastPath_AlreadyLoaded(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateReady + s := newSerial(eff) + + s.OnRequest(req("a")) + + if got := len(eff.starts); got != 0 { + t.Errorf("StartSwap calls=%d want 0 (already loaded, no swap)", got) + } + if got := eff.served("a"); got != 1 { + t.Errorf("served(a)=%d want 1", got) + } +} + +func TestSerial_ColdStart_LoadsThenServes(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(req("a")) + if got := eff.startsFor("a"); got != 1 { + t.Fatalf("StartSwap(a)=%d want 1", got) + } + if got := eff.served("a"); got != 0 { + t.Errorf("served(a)=%d want 0 before load completes", got) + } + + eff.states["a"] = process.StateReady + s.OnSwapDone(SwapDone{ModelID: "a"}) + if got := eff.served("a"); got != 1 { + t.Errorf("served(a)=%d want 1 after load", got) + } +} + +func TestSerial_UnknownModel(t *testing.T) { + eff := newFakeEffects() // no states => unknown + s := newSerial(eff) + + s.OnRequest(req("ghost")) + + if len(eff.starts) != 0 { + t.Errorf("StartSwap calls=%d want 0", len(eff.starts)) + } + if eff.errored("ghost") != 1 { + t.Fatalf("errored(ghost)=%d want 1", eff.errored("ghost")) + } + if !errors.Is(eff.grants[0].err, ErrModelNotFound) { + t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err) + } +} + +func TestSerial_EvictsEveryOtherModel(t *testing.T) { + eff := newFakeEffects() + eff.states["x"] = process.StateReady // already running + eff.states["y"] = process.StateReady // also running (e.g. left over) + eff.states["a"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(req("a")) + + st := lastStart(t, eff) + if st.model != "a" { + t.Fatalf("loading %s want a", st.model) + } + if !sameSet(st.evict, []string{"x", "y"}) { + t.Errorf("evict=%v want [x y] (serial evicts ALL other models)", st.evict) + } +} + +// TestSerial_OneJobAtATime verifies a second request waits while the first is +// serving, and only starts after the first finishes. +func TestSerial_OneJobAtATime(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateReady + eff.states["b"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(req("a")) // served immediately + s.OnRequest(req("b")) // must wait — a is serving + + if got := eff.startsFor("b"); got != 0 { + t.Fatalf("StartSwap(b)=%d want 0 while a is serving", got) + } + if got := eff.served("a"); got != 1 { + t.Fatalf("served(a)=%d want 1", got) + } + + // a finishes -> b may now load (evicting a). + s.OnServeDone(ServeDoneEvent{ModelID: "a"}) + if got := eff.startsFor("b"); got != 1 { + t.Fatalf("StartSwap(b)=%d want 1 after a finished", got) + } + if st := lastStart(t, eff); !sameSet(st.evict, []string{"a"}) { + t.Errorf("b evict=%v want [a]", st.evict) + } +} + +// TestSerial_SameModelConsecutive_NoReload verifies back-to-back requests for the +// already-loaded model run without a reload, one after another. +func TestSerial_SameModelConsecutive_NoReload(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(req("a")) // cold load + s.OnRequest(req("a")) // queued behind the first + + eff.states["a"] = process.StateReady + s.OnSwapDone(SwapDone{ModelID: "a"}) // first serves + if got := eff.served("a"); got != 1 { + t.Fatalf("served(a)=%d want 1 (one at a time)", got) + } + + s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // first done -> second serves + if got := eff.served("a"); got != 2 { + t.Fatalf("served(a)=%d want 2", got) + } + if got := eff.startsFor("a"); got != 1 { + t.Errorf("StartSwap(a)=%d want 1 (second request must not reload)", got) + } +} + +// TestSerial_StrictArrivalOrder is the core guarantee: qwen36, qwen35, sdxl, +// qwen36 execute in EXACTLY that order with evictions between each model switch, +// including reloading qwen36 at the end even though it ran first. +func TestSerial_StrictArrivalOrder(t *testing.T) { + eff := newFakeEffects() + for _, m := range []string{"qwen36", "qwen35", "sdxl"} { + eff.states[m] = process.StateStopped + } + s := newSerial(eff) + + for _, m := range []string{"qwen36", "qwen35", "sdxl", "qwen36"} { + s.OnRequest(req(m)) + } + + // Only the first job starts loading; the rest wait their turn. + if len(eff.starts) != 1 || eff.starts[0].model != "qwen36" { + t.Fatalf("starts=%+v want only [qwen36] loading first", eff.starts) + } + + // step completes the current model's load+serve and returns control to the + // scheduler, which must start the next queued model. + step := func(model string, wantEvict []string) { + t.Helper() + st := lastStart(t, eff) + if st.model != model { + t.Fatalf("loading %q want %q", st.model, model) + } + if !sameSet(st.evict, wantEvict) { + t.Fatalf("loading %q evict=%v want %v", model, st.evict, wantEvict) + } + // Simulate the eviction + load actually happening. + for _, e := range st.evict { + eff.states[e] = process.StateStopped + } + eff.states[model] = process.StateReady + s.OnSwapDone(SwapDone{ModelID: model}) + s.OnServeDone(ServeDoneEvent{ModelID: model}) + } + + step("qwen36", nil) // cold load, nothing else running + step("qwen35", []string{"qwen36"}) // evict qwen36 + step("sdxl", []string{"qwen35"}) // evict qwen35 + step("qwen36", []string{"sdxl"}) // RELOAD qwen36, evict sdxl + + want := []string{"qwen36", "qwen35", "sdxl", "qwen36"} + if got := servedOrder(eff); !sameOrder(got, want) { + t.Fatalf("serve order=%v want %v", got, want) + } +} + +func sameOrder(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestSerial_SwapError_FailsCallerAndAdvances(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateStopped + eff.states["b"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(req("a")) + s.OnRequest(req("b")) // queued behind a + + // a's load fails: its caller is errored and b proceeds. + s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")}) + if eff.errored("a") != 1 { + t.Fatalf("errored(a)=%d want 1", eff.errored("a")) + } + if got := eff.startsFor("b"); got != 1 { + t.Fatalf("StartSwap(b)=%d want 1 after a's load failed", got) + } +} + +// TestSerial_GrantServeFalse_Advances verifies that when the active request's +// caller has disconnected by serve time, the queue advances to the next request. +func TestSerial_GrantServeFalse_Advances(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateStopped + eff.states["b"] = process.StateStopped + eff.serveResult["a"] = false // a's caller is gone by grant time + s := newSerial(eff) + + s.OnRequest(req("a")) + s.OnRequest(req("b")) // queued + + eff.states["a"] = process.StateReady + s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails -> advance to b + + if got := eff.served("a"); got != 0 { + t.Errorf("served(a)=%d want 0 (caller gone)", got) + } + if got := eff.startsFor("b"); got != 1 { + t.Fatalf("StartSwap(b)=%d want 1 (advanced after gone caller)", got) + } +} + +func TestSerial_OnCancel_QueuedRequest(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateStopped + eff.states["b"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(reqCh("a")) // starts loading a + cancelled := reqCh("b") + s.OnRequest(cancelled) // queued behind a + if len(s.queued) != 1 { + t.Fatalf("queued=%d want 1", len(s.queued)) + } + + s.OnCancel(cancelled) + if len(s.queued) != 0 { + t.Fatalf("queued=%d want 0 after cancel", len(s.queued)) + } + + // a completes; b is gone, so nothing starts for it. + eff.states["a"] = process.StateReady + s.OnSwapDone(SwapDone{ModelID: "a"}) + s.OnServeDone(ServeDoneEvent{ModelID: "a"}) + if got := eff.startsFor("b"); got != 0 { + t.Errorf("StartSwap(b)=%d want 0 (cancelled before its turn)", got) + } +} + +func TestSerial_OnShutdown_FailsQueuedAndActiveLoad(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateStopped + eff.states["b"] = process.StateStopped + eff.states["c"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(req("a")) // active (loading) + s.OnRequest(req("b")) // queued + s.OnRequest(req("c")) // queued + + s.OnShutdown(errors.New("shutting down")) + + if got := eff.errored(""); got != 3 { + t.Errorf("error grants=%d want 3 (active load + 2 queued)", got) + } + if len(s.queued) != 0 { + t.Errorf("queued=%d want 0 after shutdown", len(s.queued)) + } +} + +// TestSerial_OnUnload_WhileServing verifies that unloading the model that is +// actively serving does not strand the queue: OnUnload stops the process but +// leaves the active request to end via OnServeDone, which then advances. +func TestSerial_OnUnload_WhileServing(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateReady + eff.states["b"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(req("a")) // served immediately (a ready) + s.OnRequest(req("b")) // queued behind a + if got := eff.served("a"); got != 1 { + t.Fatalf("served(a)=%d want 1", got) + } + + // Unload a while it is serving: the process is stopped, but the queue must + // not advance yet — the active serve is still outstanding. + s.OnUnload([]string{"a"}, time.Second) + if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) { + t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops) + } + if got := eff.startsFor("b"); got != 0 { + t.Fatalf("StartSwap(b)=%d want 0 before the serving request ends", got) + } + + // The killed handler returns -> OnServeDone advances to b. + eff.states["a"] = process.StateStopped + s.OnServeDone(ServeDoneEvent{ModelID: "a"}) + if got := eff.startsFor("b"); got != 1 { + t.Fatalf("StartSwap(b)=%d want 1 after the serving request ended", got) + } +} + +func TestSerial_OnUnload_DropsQueuedAndStops(t *testing.T) { + eff := newFakeEffects() + eff.states["a"] = process.StateStopped + eff.states["b"] = process.StateStopped + s := newSerial(eff) + + s.OnRequest(req("a")) // active (loading a) + s.OnRequest(req("b")) // queued + + // Unload a: its active load is failed and a is stopped. + s.OnUnload([]string{"a"}, time.Second) + + if eff.errored("a") != 1 { + t.Errorf("errored(a)=%d want 1 (active load failed)", eff.errored("a")) + } + if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) { + t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops) + } + // b was queued and not unloaded; with a's load cancelled it now starts. + if got := eff.startsFor("b"); got != 1 { + t.Errorf("StartSwap(b)=%d want 1 after unload advanced the queue", got) + } +}