package scheduler

import (
	"errors"
	"io"
	"testing"
	"time"

	"github.com/mostlygeek/llama-swap/internal/logmon"
	"github.com/mostlygeek/llama-swap/internal/process"
)

// Serial methods all run on the router's single run-loop goroutine, so these
// tests drive them directly and synchronously, reusing fakeEffects and the
// req/reqCh helpers from fifo_test.go. A load completes via OnSwapDone and a
// served request finishes via OnServeDone — the events the run loop delivers.

// newSerial builds a Serial named "test" wired to eff, with its log output
// sent to io.Discard.
func newSerial(eff Effects) *Serial {
	sink := logmon.NewWriter(io.Discard)
	return NewSerial("test", sink, eff)
}

// lastStart returns the newest StartSwap record held by eff, failing the test
// immediately when none has been recorded.
func lastStart(t *testing.T, eff *fakeEffects) startRec {
	t.Helper()
	n := len(eff.starts)
	if n == 0 {
		t.Fatal("no StartSwap recorded")
	}
	return eff.starts[n-1]
}

// sameSet reports whether a and b hold the same strings with the same
// multiplicities, regardless of order.
func sameSet(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	remaining := make(map[string]int, len(a))
	for _, id := range a {
		remaining[id]++
	}
	// The lengths match, so b equals a as a multiset exactly when every
	// element of b can be paired with a still-unmatched element of a.
	for _, id := range b {
		if remaining[id] == 0 {
			return false
		}
		remaining[id]--
	}
	return true
}

// servedOrder returns the model IDs of every successful serve grant in order.
func servedOrder(eff *fakeEffects) []string { var out []string for _, g := range eff.grants { if g.err == nil && g.serve { out = append(out, g.model) } } return out } func TestSerial_FastPath_AlreadyLoaded(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateReady s := newSerial(eff) s.OnRequest(req("a")) if got := len(eff.starts); got != 0 { t.Errorf("StartSwap calls=%d want 0 (already loaded, no swap)", got) } if got := eff.served("a"); got != 1 { t.Errorf("served(a)=%d want 1", got) } } func TestSerial_ColdStart_LoadsThenServes(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateStopped s := newSerial(eff) s.OnRequest(req("a")) if got := eff.startsFor("a"); got != 1 { t.Fatalf("StartSwap(a)=%d want 1", got) } if got := eff.served("a"); got != 0 { t.Errorf("served(a)=%d want 0 before load completes", got) } eff.states["a"] = process.StateReady s.OnSwapDone(SwapDone{ModelID: "a"}) if got := eff.served("a"); got != 1 { t.Errorf("served(a)=%d want 1 after load", got) } } func TestSerial_UnknownModel(t *testing.T) { eff := newFakeEffects() // no states => unknown s := newSerial(eff) s.OnRequest(req("ghost")) if len(eff.starts) != 0 { t.Errorf("StartSwap calls=%d want 0", len(eff.starts)) } if eff.errored("ghost") != 1 { t.Fatalf("errored(ghost)=%d want 1", eff.errored("ghost")) } if !errors.Is(eff.grants[0].err, ErrModelNotFound) { t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err) } } func TestSerial_EvictsEveryOtherModel(t *testing.T) { eff := newFakeEffects() eff.states["x"] = process.StateReady // already running eff.states["y"] = process.StateReady // also running (e.g. 
left over) eff.states["a"] = process.StateStopped s := newSerial(eff) s.OnRequest(req("a")) st := lastStart(t, eff) if st.model != "a" { t.Fatalf("loading %s want a", st.model) } if !sameSet(st.evict, []string{"x", "y"}) { t.Errorf("evict=%v want [x y] (serial evicts ALL other models)", st.evict) } } // TestSerial_OneJobAtATime verifies a second request waits while the first is // serving, and only starts after the first finishes. func TestSerial_OneJobAtATime(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateReady eff.states["b"] = process.StateStopped s := newSerial(eff) s.OnRequest(req("a")) // served immediately s.OnRequest(req("b")) // must wait — a is serving if got := eff.startsFor("b"); got != 0 { t.Fatalf("StartSwap(b)=%d want 0 while a is serving", got) } if got := eff.served("a"); got != 1 { t.Fatalf("served(a)=%d want 1", got) } // a finishes -> b may now load (evicting a). s.OnServeDone(ServeDoneEvent{ModelID: "a"}) if got := eff.startsFor("b"); got != 1 { t.Fatalf("StartSwap(b)=%d want 1 after a finished", got) } if st := lastStart(t, eff); !sameSet(st.evict, []string{"a"}) { t.Errorf("b evict=%v want [a]", st.evict) } } // TestSerial_SameModelConsecutive_NoReload verifies back-to-back requests for the // already-loaded model run without a reload, one after another. 
func TestSerial_SameModelConsecutive_NoReload(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateStopped s := newSerial(eff) s.OnRequest(req("a")) // cold load s.OnRequest(req("a")) // queued behind the first eff.states["a"] = process.StateReady s.OnSwapDone(SwapDone{ModelID: "a"}) // first serves if got := eff.served("a"); got != 1 { t.Fatalf("served(a)=%d want 1 (one at a time)", got) } s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // first done -> second serves if got := eff.served("a"); got != 2 { t.Fatalf("served(a)=%d want 2", got) } if got := eff.startsFor("a"); got != 1 { t.Errorf("StartSwap(a)=%d want 1 (second request must not reload)", got) } } // TestSerial_StrictArrivalOrder is the core guarantee: qwen36, qwen35, sdxl, // qwen36 execute in EXACTLY that order with evictions between each model switch, // including reloading qwen36 at the end even though it ran first. func TestSerial_StrictArrivalOrder(t *testing.T) { eff := newFakeEffects() for _, m := range []string{"qwen36", "qwen35", "sdxl"} { eff.states[m] = process.StateStopped } s := newSerial(eff) for _, m := range []string{"qwen36", "qwen35", "sdxl", "qwen36"} { s.OnRequest(req(m)) } // Only the first job starts loading; the rest wait their turn. if len(eff.starts) != 1 || eff.starts[0].model != "qwen36" { t.Fatalf("starts=%+v want only [qwen36] loading first", eff.starts) } // step completes the current model's load+serve and returns control to the // scheduler, which must start the next queued model. step := func(model string, wantEvict []string) { t.Helper() st := lastStart(t, eff) if st.model != model { t.Fatalf("loading %q want %q", st.model, model) } if !sameSet(st.evict, wantEvict) { t.Fatalf("loading %q evict=%v want %v", model, st.evict, wantEvict) } // Simulate the eviction + load actually happening. 
for _, e := range st.evict { eff.states[e] = process.StateStopped } eff.states[model] = process.StateReady s.OnSwapDone(SwapDone{ModelID: model}) s.OnServeDone(ServeDoneEvent{ModelID: model}) } step("qwen36", nil) // cold load, nothing else running step("qwen35", []string{"qwen36"}) // evict qwen36 step("sdxl", []string{"qwen35"}) // evict qwen35 step("qwen36", []string{"sdxl"}) // RELOAD qwen36, evict sdxl want := []string{"qwen36", "qwen35", "sdxl", "qwen36"} if got := servedOrder(eff); !sameOrder(got, want) { t.Fatalf("serve order=%v want %v", got, want) } } func sameOrder(a, b []string) bool { if len(a) != len(b) { return false } for i := range a { if a[i] != b[i] { return false } } return true } func TestSerial_SwapError_FailsCallerAndAdvances(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateStopped eff.states["b"] = process.StateStopped s := newSerial(eff) s.OnRequest(req("a")) s.OnRequest(req("b")) // queued behind a // a's load fails: its caller is errored and b proceeds. s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")}) if eff.errored("a") != 1 { t.Fatalf("errored(a)=%d want 1", eff.errored("a")) } if got := eff.startsFor("b"); got != 1 { t.Fatalf("StartSwap(b)=%d want 1 after a's load failed", got) } } // TestSerial_GrantServeFalse_Advances verifies that when the active request's // caller has disconnected by serve time, the queue advances to the next request. 
func TestSerial_GrantServeFalse_Advances(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateStopped eff.states["b"] = process.StateStopped eff.serveResult["a"] = false // a's caller is gone by grant time s := newSerial(eff) s.OnRequest(req("a")) s.OnRequest(req("b")) // queued eff.states["a"] = process.StateReady s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails -> advance to b if got := eff.served("a"); got != 0 { t.Errorf("served(a)=%d want 0 (caller gone)", got) } if got := eff.startsFor("b"); got != 1 { t.Fatalf("StartSwap(b)=%d want 1 (advanced after gone caller)", got) } } func TestSerial_OnCancel_QueuedRequest(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateStopped eff.states["b"] = process.StateStopped s := newSerial(eff) s.OnRequest(reqCh("a")) // starts loading a cancelled := reqCh("b") s.OnRequest(cancelled) // queued behind a if len(s.queued) != 1 { t.Fatalf("queued=%d want 1", len(s.queued)) } s.OnCancel(cancelled) if len(s.queued) != 0 { t.Fatalf("queued=%d want 0 after cancel", len(s.queued)) } // a completes; b is gone, so nothing starts for it. 
eff.states["a"] = process.StateReady s.OnSwapDone(SwapDone{ModelID: "a"}) s.OnServeDone(ServeDoneEvent{ModelID: "a"}) if got := eff.startsFor("b"); got != 0 { t.Errorf("StartSwap(b)=%d want 0 (cancelled before its turn)", got) } } func TestSerial_OnShutdown_FailsQueuedAndActiveLoad(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateStopped eff.states["b"] = process.StateStopped eff.states["c"] = process.StateStopped s := newSerial(eff) s.OnRequest(req("a")) // active (loading) s.OnRequest(req("b")) // queued s.OnRequest(req("c")) // queued s.OnShutdown(errors.New("shutting down")) if got := eff.errored(""); got != 3 { t.Errorf("error grants=%d want 3 (active load + 2 queued)", got) } if len(s.queued) != 0 { t.Errorf("queued=%d want 0 after shutdown", len(s.queued)) } } // TestSerial_OnUnload_WhileServing verifies that unloading the model that is // actively serving does not strand the queue: OnUnload stops the process but // leaves the active request to end via OnServeDone, which then advances. func TestSerial_OnUnload_WhileServing(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateReady eff.states["b"] = process.StateStopped s := newSerial(eff) s.OnRequest(req("a")) // served immediately (a ready) s.OnRequest(req("b")) // queued behind a if got := eff.served("a"); got != 1 { t.Fatalf("served(a)=%d want 1", got) } // Unload a while it is serving: the process is stopped, but the queue must // not advance yet — the active serve is still outstanding. s.OnUnload([]string{"a"}, time.Second) if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) { t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops) } if got := eff.startsFor("b"); got != 0 { t.Fatalf("StartSwap(b)=%d want 0 before the serving request ends", got) } // The killed handler returns -> OnServeDone advances to b. 
eff.states["a"] = process.StateStopped s.OnServeDone(ServeDoneEvent{ModelID: "a"}) if got := eff.startsFor("b"); got != 1 { t.Fatalf("StartSwap(b)=%d want 1 after the serving request ended", got) } } func TestSerial_OnUnload_DropsQueuedAndStops(t *testing.T) { eff := newFakeEffects() eff.states["a"] = process.StateStopped eff.states["b"] = process.StateStopped s := newSerial(eff) s.OnRequest(req("a")) // active (loading a) s.OnRequest(req("b")) // queued // Unload a: its active load is failed and a is stopped. s.OnUnload([]string{"a"}, time.Second) if eff.errored("a") != 1 { t.Errorf("errored(a)=%d want 1 (active load failed)", eff.errored("a")) } if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) { t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops) } // b was queued and not unloaded; with a's load cancelled it now starts. if got := eff.startsFor("b"); got != 1 { t.Errorf("StartSwap(b)=%d want 1 after unload advanced the queue", got) } }