Introduce new routing backend (#790)

This is a huge backend change that essentially started with rewriting
the concurrency handling for processes and blew up to a refactor of the
entire application. In short these are the improvements:

**Better state and life cycle management:** 

Life cycle management of processes has always been the trickiest part of
the code. Juggling mutex locks between multiple locations to reduce race
conditions was complex. Too complex for my feeble brain to build a
simple mental model around as llama-swap gained more features. All of
that has been refactored. Most of the locks are gone, replaced with a
single run() that owns all state changes. There is one place to start
from now to understand and extend routing logic.

The improved life cycle management makes it easier to implement more
complex swap optimization strategies in the future like #727.

**Collation of requests:**

llama-swap previously handled requests and swapping in the order they
came in. For example requests for models in this order ABCABC would
result in 5 swaps. Now those requests are handled in this order AABBCC.
The result is less time waiting for swap under a high churn request
queue. This fixes #588 #612.

A possible future enhancement is to support a starvation parameter so
swap can be forced when models have been waiting too long.

**Shared base implementation for groups and swap matrix:** 

During the refactor it became clear that much of the swapping logic was
shared between these two implementations. That is not surprising
considering the swap matrix was added many moons after groups. Now they
share a common base and their specific swap strategies are implemented
into the swapPlanner interface.

Requests for bespoke or specific swapping scenarios is a common theme in
the issues. Now users can implement whatever bespoke and weird swapping
strategy they want in their own fork. Just ask your agent of choice to
implement swapPlanner. I'll still remaining more conservative on what
actually lands in core llama-swap and will continue to evaluate PRs if
the changes is good for everyone or just one specific use case.

**AI / Agentic Disclosure:** 

I paid very close attention to the low level swap concurrency design and
implementation. It's important to keep that essential part reliable,
boring and no surprises. Backwards compatibility was also maintained,
even the one way non-exclusive group model loading behaviour that people
have rightly pointed out be a weird design decision.

With the underlying swap core done the web server, api and UI sitting on
top were largely ported over with Claude Code and Opus 4.7 in multiple
phases. If you're curious I kept the changes in docs/newrouter-todo.md.
I did several passes to make sure things weren't left behind.

However, even frontier LLMs at the time of this PR still make small
decisions that don't make a lot of sense. They get shit wrong all the
time, just in small subtle way.

That said, there's likely to be some new bugs introduced with this
massive refactor. I'm fairly confident that there's no major
architectural flaws that would cause goal seeking agents to make dumb,
ugly code decisions.

For a little while the legacy llama-swap will be available under
cmd/legacy/llama-swap. The plan is to eventually delete that entry point
as well as the proxy package.

On a bit of a personal note, this PR is exciting and a bit sad for me. I
hand wrote much of the original code and this PR ultimately replaces
much of it. While the old code served as a good reference for the agent
to implement the new stuff it still a bit sad to eventually delete it
all.
This commit is contained in:
Benson Wong
2026-05-28 21:47:01 -07:00
committed by GitHub
parent 63bc266395
commit 02e015fa49
107 changed files with 12014 additions and 251 deletions
+775
View File
@@ -0,0 +1,775 @@
package router
import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type shutdownReq struct {
timeout time.Duration
respond chan error
}
type unloadReq struct {
targets []string
timeout time.Duration
respond chan struct{}
}
type handlerReq struct {
model string
ctx context.Context
respond chan handlerResp
positionCh chan int
}
type handlerResp struct {
handleFunc http.HandlerFunc
err error
}
type swapDone struct {
modelID string
err error
}
type serveDoneEvent struct {
modelID string
}
type activeSwap struct {
modelID string
evict []string
waiters []handlerReq
}
// swapPlanner is the only piece of behaviour that differs between concrete
// routers. baseRouter never inspects its internals.
type swapPlanner interface {
// EvictionFor returns running model IDs that must be stopped before
// target can serve. alsoRunning lists models the baseRouter has already
// committed to loading (in-flight swaps) which the planner cannot see
// via process.State() yet. Pure decision; must not log.
EvictionFor(target string, alsoRunning []string) []string
// OnSwapStart runs once at the start of every swap. Planners may log
// their decision here at whatever verbosity they choose.
OnSwapStart(target string)
}
// baseRouter owns the channels, run-loop, and orchestration code shared by
// every concrete router. Concrete routers embed *baseRouter and supply a
// swapPlanner that captures how their eviction set is decided.
type baseRouter struct {
name string
config config.Config
processes map[string]process.Process
logger *logmon.Monitor
planner swapPlanner
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
handlerCh chan handlerReq
shutdownCh chan shutdownReq
unloadCh chan unloadReq
swapDoneCh chan swapDone
serveDoneCh chan serveDoneEvent
runDone chan struct{}
// testProcessed, when non-nil, receives one event after each handlerReq
// or swapDone has been fully processed by run(). Tests use it to wait
// for run() to reach a deterministic state without sleeping. serveDone
// events are intentionally NOT signalled here so test event counts
// remain stable.
testProcessed chan struct{}
}
func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter {
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
return &baseRouter{
name: name,
config: conf,
processes: processes,
logger: logger,
planner: planner,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
handlerCh: make(chan handlerReq),
shutdownCh: make(chan shutdownReq),
unloadCh: make(chan unloadReq),
swapDoneCh: make(chan swapDone),
serveDoneCh: make(chan serveDoneEvent),
runDone: make(chan struct{}),
}
}
func (b *baseRouter) notifyProcessed() {
if b.testProcessed != nil {
b.testProcessed <- struct{}{}
}
}
func (b *baseRouter) run() {
defer close(b.runDone)
active := make(map[string]*activeSwap)
inFlight := make(map[string]int)
var queued []handlerReq
for {
select {
case req := <-b.shutdownCh:
b.handleShutdown(req, active, queued)
return
case req := <-b.handlerCh:
b.handleRequest(req, active, inFlight, &queued)
b.notifyProcessed()
case req := <-b.unloadCh:
b.handleUnload(req, active, inFlight, &queued)
b.notifyProcessed()
case ev := <-b.swapDoneCh:
b.handleSwapDone(ev, active, inFlight, &queued)
b.notifyProcessed()
case ev := <-b.serveDoneCh:
b.handleServeDone(ev, active, inFlight, &queued)
}
}
}
// grant sends a response back to the caller of ServeHTTP and tells us
// whether the caller was still there to receive it.
//
// Each ServeHTTP creates a fresh, UNBUFFERED respond channel and parks in
// a select waiting on it. "Unbuffered" is the important word: a send only
// completes when the other side is actively receiving. So if this send
// succeeds, we know for a fact the caller picked up the response and will
// act on it. If the caller has already given up (its request context was
// cancelled, e.g. the HTTP client disconnected) or the router is shutting
// down, the send never lands, one of the other select cases fires, and we
// report back that the grant did NOT happen.
//
// That distinction matters for in-flight bookkeeping — see grantHandler.
func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool {
select {
case req.respond <- resp:
return true
case <-req.ctx.Done():
return false
case <-b.shutdownCtx.Done():
return false
}
}
// grantHandler is the "this caller can now use process p" path. It does
// two things that must stay locked together:
//
// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the
// HTTP request finishes, the run loop hears about it.
// 2. Bump inFlight[modelID] so the router knows this process is busy and
// refuses to evict it until the count comes back down.
//
// The increment is gated on grant() returning true. If grant() returns
// false, the caller already walked away and trackedServe will never run —
// which means no matching decrement will ever arrive on serveDoneCh.
// Incrementing in that case would strand the counter at >0 forever and
// the router would never again be willing to swap this model out.
//
// In short: increment if and only if we know a decrement is coming.
func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) {
if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) {
inFlight[modelID]++
}
}
// trackedServe is the wrapper that closes the loop on in-flight tracking.
// It runs p.ServeHTTP normally; the only added behaviour is a deferred
// send on serveDoneCh after the handler returns. That send is what tells
// the run loop "this model now has one fewer request in flight — go look
// at the queue again, you may be able to start a swap you previously had
// to defer."
//
// The select on shutdownCtx.Done() is a release valve: if the router is
// already shutting down, nobody is reading serveDoneCh, so we drop the
// notification rather than blocking the HTTP goroutine forever.
func (b *baseRouter) trackedServe(modelID string, p process.Process) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
select {
case b.serveDoneCh <- serveDoneEvent{modelID: modelID}:
case <-b.shutdownCtx.Done():
}
}()
p.ServeHTTP(w, r)
}
}
// handleRequest decides what to do with one incoming ServeHTTP request. It is
// called from run() and never blocks indefinitely: any work that has to wait
// (starting a process, stopping siblings, waiting for ready) is deferred to
// a swap goroutine and reported back via swapDoneCh.
//
// The decision tree, in order:
//
// 1. Unknown model — respond with ErrNoLocalModelFound and move on.
// 2. A swap to the same model is already in flight — attach this waiter so
// one swap serves all callers that asked for the same model.
// 3. Fast path — the target process is already ready, the planner sees
// nothing to evict, and no in-flight swap is evicting it. Hand back its
// ServeHTTP immediately (wrapped so the run loop knows when it ends).
// 4. Would collide with an in-flight swap (we'd stop their target, or
// they're stopping us) — park in the queue for handleSwapDone to drain.
// 5. Would evict a process that is still handling requests — park in the
// queue. handleServeDone will retry when the busy process drains.
// 6. Otherwise — start a new swap. This may run in parallel with other
// active swaps when their evict sets don't intersect.
func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
// (1) Unknown model.
p, ok := b.processes[req.model]
if !ok {
b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model)
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
return
}
// (2) Join an in-flight swap for the same model.
if s, ok := active[req.model]; ok {
b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1)
s.waiters = append(s.waiters, req)
return
}
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model)
b.grantHandler(req, req.model, p, inFlight)
return
}
// (4) Collision with an in-flight swap — queue.
if collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model)
*queued = append(*queued, req)
b.broadcastQueuePositions(*queued)
return
}
// (5) Would evict a busy process — queue until it drains.
if conflictsWithInFlight(evict, inFlight) {
b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model)
*queued = append(*queued, req)
b.broadcastQueuePositions(*queued)
return
}
// (6) Start a new (possibly parallel) swap.
b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict)
s := b.startSwap(req, evict)
active[s.modelID] = s
}
// handleSwapDone is called from run() when a swap goroutine reports that it
// has finished. It fans out the result to every waiter that joined this swap,
// removes the swap from the active map, and then walks the queue once,
// promoting any items that no longer collide with the remaining active set.
// FIFO order is preserved: items still blocked stay in place.
func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
s, ok := active[ev.modelID]
if !ok {
return
}
delete(active, ev.modelID)
for _, w := range s.waiters {
if ev.err != nil {
b.grant(w, handlerResp{err: ev.err})
} else {
p := b.processes[ev.modelID]
b.grantHandler(w, ev.modelID, p, inFlight)
}
}
b.drainQueue(active, inFlight, queued)
}
// handleServeDone is called from run() each time a tracked ServeHTTP
// finishes. It decrements the per-model in-flight count and, when that
// drops to zero, retries the queue: requests whose swap was deferred
// because they would have evicted this (now-idle) process can now proceed.
func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
inFlight[ev.modelID]--
if inFlight[ev.modelID] <= 0 {
delete(inFlight, ev.modelID)
b.drainQueue(active, inFlight, queued)
}
}
// drainQueue walks the queued requests in order, re-running the handleRequest
// decision tree against the (now smaller) active set. Items that can now start
// or join become satisfied; items still blocked remain queued in original
// order so they get another chance on the next swap completion.
func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
if len(*queued) == 0 {
return
}
pending := *queued
var remaining []handlerReq
for _, req := range pending {
p, ok := b.processes[req.model]
if !ok {
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
continue
}
if s, ok := active[req.model]; ok {
b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model)
s.waiters = append(s.waiters, req)
continue
}
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model)
b.grantHandler(req, req.model, p, inFlight)
continue
}
if collidesWith(req.model, evict, active) {
remaining = append(remaining, req)
continue
}
if conflictsWithInFlight(evict, inFlight) {
remaining = append(remaining, req)
continue
}
b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict)
s := b.startSwap(req, evict)
active[s.modelID] = s
}
*queued = remaining
b.broadcastQueuePositions(*queued)
}
// broadcastQueuePositions sends each queued request its current 1-indexed
// position. Sends are non-blocking: if the channel is full, the old value is
// drained first so the consumer always sees the latest position.
func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) {
for i, req := range queued {
pos := i + 1
select {
case req.positionCh <- pos:
default:
select {
case <-req.positionCh:
default:
}
select {
case req.positionCh <- pos:
default:
}
}
}
}
func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap {
swap := &activeSwap{
modelID: initial.model,
evict: evict,
waiters: []handlerReq{initial},
}
b.planner.OnSwapStart(initial.model)
go b.doSwap(initial.model, evict)
return swap
}
// activeTargets returns the IDs of every in-flight swap target except exclude.
// baseRouter passes this to the planner so eviction decisions account for
// models that have been committed to but have not yet transitioned to
// StateStarting in their process state machine.
func activeTargets(active map[string]*activeSwap, exclude string) []string {
if len(active) == 0 {
return nil
}
out := make([]string, 0, len(active))
for id := range active {
if id == exclude {
continue
}
out = append(out, id)
}
return out
}
// collidesWith reports whether a new swap with this target and evict set can
// safely run alongside the currently active swaps. Same-target callers should
// JOIN (handled before this) — they do not collide with themselves.
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
for id, s := range active {
if id == target {
continue
}
if containsString(evict, id) {
return true
}
if containsString(s.evict, target) {
return true
}
}
return false
}
// conflictsWithInFlight reports whether any model in evict is still handling
// requests. Stopping a busy process would cancel its callers' connections,
// so the router defers the swap until those callers finish.
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
for _, m := range evict {
if inFlight[m] > 0 {
return true
}
}
return false
}
func containsString(xs []string, s string) bool {
for _, x := range xs {
if x == s {
return true
}
}
return false
}
func (b *baseRouter) doSwap(modelID string, toStop []string) {
timeout := b.healthCheckTimeout()
var wg sync.WaitGroup
for _, mID := range toStop {
wg.Add(1)
go func(p process.Process, id string) {
defer wg.Done()
if err := p.Stop(timeout); err != nil {
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
}
}(b.processes[mID], mID)
}
wg.Wait()
target := b.processes[modelID]
if target.State() == process.StateStopped {
go func() {
if err := target.Run(timeout); err != nil {
b.logger.Warnf("%s: running %s exited: %v", b.name, modelID, err)
}
}()
}
err := target.WaitReady(b.shutdownCtx)
select {
case b.swapDoneCh <- swapDone{modelID: modelID, err: err}:
case <-b.shutdownCtx.Done():
}
}
func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) {
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
// Cancel shutdownCtx first so any waiter that is currently parked on
// its respond channel can exit via its own shutdownCtx.Done() branch.
// The grant calls below then either land (waiter happened to receive
// before noticing shutdown) or fall through immediately via grant's
// shutdownCtx case — either way the waiter sees a non-OK response.
b.shutdownFn()
for _, s := range active {
for _, w := range s.waiters {
b.grant(w, handlerResp{err: shutdownErr})
}
}
for _, w := range queued {
b.grant(w, handlerResp{err: shutdownErr})
}
stopTimeout := req.timeout
if stopTimeout <= 0 {
stopTimeout = b.healthCheckTimeout()
}
var wg sync.WaitGroup
for i, p := range b.processes {
wg.Add(1)
go func(id string, p process.Process) {
defer wg.Done()
if err := p.Stop(stopTimeout); err != nil {
b.logger.Warnf("%s failed to stop process %s: %v", b.name, id, err)
}
}(i, p)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
if req.timeout > 0 {
select {
case <-done:
case <-time.After(req.timeout):
<-done
}
} else {
<-done
}
req.respond <- nil
}
func (b *baseRouter) healthCheckTimeout() time.Duration {
t := time.Duration(b.config.HealthCheckTimeout) * time.Second
if t <= 0 {
return 30 * time.Second
}
return t
}
func (b *baseRouter) Handles(model string) bool {
_, ok := b.processes[model]
return ok
}
func (b *baseRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
if p, ok := b.processes[modelID]; ok {
return p.Logger(), true
}
return nil, false
}
// RunningModels returns the current state of every process that is not stopped
// or shut down. The processes map keys are fixed at construction and State()
// is a snapshot, so this is safe to call without the run loop.
func (b *baseRouter) RunningModels() map[string]process.ProcessState {
running := make(map[string]process.ProcessState)
for id, p := range b.processes {
st := p.State()
if st == process.StateStopped || st == process.StateShutdown {
continue
}
running[id] = st
}
return running
}
// Unload stops the named models, or every running model when none are named.
// It blocks until each targeted process has stopped.
//
// The request is funneled through the run loop so eviction is coordinated
// with the rest of the router's state: pending swap waiters for an
// unloaded model are released with an error, queued requests for unloaded
// models are dropped, and any deferred swaps that were waiting on those
// models become eligible to start.
//
// In-flight requests being served by an unloaded process are not waited
// for — Stop kills the upstream, those callers see whatever error the
// reverse proxy surfaces and may retry. Their trackedServe defers fire
// normally and decrement inFlight as the dying handlers return.
func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
targets := models
if len(targets) == 0 {
targets = make([]string, 0, len(b.processes))
for id := range b.processes {
targets = append(targets, id)
}
}
if len(targets) == 0 {
return
}
req := unloadReq{targets: targets, timeout: timeout, respond: make(chan struct{})}
select {
case b.unloadCh <- req:
case <-b.runDone:
return
}
<-req.respond
}
// handleUnload runs on the run loop in response to an Unload call. It
// reconciles router-owned state with the impending Stop, then performs
// the Stop synchronously so callers of Unload remain blocked until each
// targeted process has actually exited.
func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
unloadErr := fmt.Errorf("%s: model unloaded", b.name)
targetSet := make(map[string]bool, len(req.targets))
for _, id := range req.targets {
targetSet[id] = true
}
// Release waiters of any in-flight swap whose target is being
// unloaded. The swap goroutine itself is left to finish on its own;
// when its swapDone arrives, handleSwapDone will find no entry in
// active and silently drop it.
for id := range targetSet {
s, ok := active[id]
if !ok {
continue
}
for _, w := range s.waiters {
b.grant(w, handlerResp{err: unloadErr})
}
delete(active, id)
}
// Drop queued requests addressed to unloaded models. Requests for
// other models stay queued and may benefit from drainQueue at the end.
if len(*queued) > 0 {
kept := (*queued)[:0]
for _, w := range *queued {
if targetSet[w.model] {
b.grant(w, handlerResp{err: unloadErr})
continue
}
kept = append(kept, w)
}
*queued = kept
}
// Stop the targeted processes. Done synchronously so Unload's caller
// can rely on "after Unload returns, the process is stopped". inFlight
// is intentionally NOT cleared here: each dying handler will fire its
// trackedServe defer and reach handleServeDone in the normal way once
// the run loop is free again.
var wg sync.WaitGroup
for id := range targetSet {
p, ok := b.processes[id]
if !ok {
continue
}
wg.Add(1)
go func(id string, p process.Process) {
defer wg.Done()
if err := p.Stop(req.timeout); err != nil {
b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err)
}
}(id, p)
}
wg.Wait()
// Removing entries from active above may have unblocked queued
// requests that previously collided with the now-cancelled swaps.
b.drainQueue(active, inFlight, queued)
close(req.respond)
}
func (b *baseRouter) Shutdown(timeout time.Duration) error {
if !b.shuttingDown.CompareAndSwap(false, true) {
return fmt.Errorf("%s shutdown already in progress", b.name)
}
req := shutdownReq{timeout: timeout, respond: make(chan error, 1)}
select {
case b.shutdownCh <- req:
case <-b.runDone:
return nil
}
return <-req.respond
}
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if b.shuttingDown.Load() {
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
data, err := FetchContext(req, b.config)
if err != nil {
SendError(w, req, err)
return
}
hr := handlerReq{
model: data.ModelID,
ctx: req.Context(),
// Unbuffered: a successful send on respond proves the waiter is
// alive and consuming. grant() relies on this to avoid handing a
// handleFunc to a cancelled waiter and leaking the inFlight count.
respond: make(chan handlerResp),
positionCh: make(chan int, 1),
}
select {
case b.handlerCh <- hr:
case <-req.Context().Done():
return
case <-b.shutdownCtx.Done():
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
isModelReady := false
if p, ok := b.processes[data.ModelID]; ok {
isModelReady = p.State() == process.StateReady
}
shouldShowLoading := data.Streaming && data.SendLoadingState && isLoadingPath(req.URL.Path) && !isModelReady
var lw *loadingWriter
cancelLoad := func() {}
if shouldShowLoading {
var swapCtx context.Context
swapCtx, cancelLoad = context.WithCancel(req.Context())
lw = newLoadingWriter(b.logger, data.ModelID, w, req)
go lw.start(swapCtx)
go func() {
for {
select {
case pos := <-hr.positionCh:
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
case <-swapCtx.Done():
return
}
}
}()
}
var resp handlerResp
select {
case resp = <-hr.respond:
cancelLoad()
if lw != nil {
lw.waitForCompletion(1 * time.Second)
}
case <-req.Context().Done():
cancelLoad()
if lw != nil {
lw.waitForCompletion(1 * time.Second)
}
return
case <-b.shutdownCtx.Done():
cancelLoad()
if lw != nil {
lw.waitForCompletion(1 * time.Second)
}
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
if resp.err != nil {
SendError(w, req, resp.err)
return
}
resp.handleFunc(w, req)
}
+863
View File
@@ -0,0 +1,863 @@
package router
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// stubPlanner is a swapPlanner that returns a fixed eviction list per target
// and never logs. It lets the base-router tests cover shared run-loop
// behaviour without dragging in either real router's eviction rules.
type stubPlanner struct {
evict map[string][]string
}
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
if s.evict == nil {
return nil
}
return s.evict[target]
}
func (s *stubPlanner) OnSwapStart(string) {}
func newTestBase(t *testing.T, processes map[string]process.Process, planner swapPlanner) *baseRouter {
t.Helper()
conf := config.Config{HealthCheckTimeout: 5}
b := newBaseRouter("test", conf, processes, planner, logmon.NewWriter(io.Discard))
b.testProcessed = make(chan struct{}, 64)
go b.run()
t.Cleanup(func() {
if !b.shuttingDown.Load() {
_ = b.Shutdown(time.Second)
}
})
return b
}
func TestBaseRouter_RunningModels(t *testing.T) {
ready := newFakeProcess("ready")
ready.markReady()
starting := newFakeProcess("starting")
starting.setState(process.StateStarting)
stopped := newFakeProcess("stopped")
b := newTestBase(t, map[string]process.Process{
"ready": ready, "starting": starting, "stopped": stopped,
}, &stubPlanner{})
running := b.RunningModels()
if len(running) != 2 {
t.Fatalf("running=%v want 2 entries", running)
}
if running["ready"] != process.StateReady {
t.Errorf("ready state=%q want ready", running["ready"])
}
if running["starting"] != process.StateStarting {
t.Errorf("starting state=%q want starting", running["starting"])
}
if _, ok := running["stopped"]; ok {
t.Errorf("stopped process should be excluded from RunningModels")
}
}
func TestBaseRouter_UnloadAll(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
c := newFakeProcess("c")
c.markReady()
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
b.Unload(time.Second)
if a.State() != process.StateStopped || c.State() != process.StateStopped {
t.Fatalf("Unload() should stop every process: a=%q c=%q", a.State(), c.State())
}
}
func TestBaseRouter_UnloadSpecificModel(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
c := newFakeProcess("c")
c.markReady()
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
b.Unload(time.Second, "a")
if a.State() != process.StateStopped {
t.Errorf("a should be stopped, got %q", a.State())
}
if c.State() != process.StateReady {
t.Errorf("c should remain ready, got %q", c.State())
}
}
// TestBaseRouter_Unload_StopsInParallel verifies that Unload fans out its
// Stop calls concurrently rather than stopping each process serially. Each
// fakeProcess.Stop is pinned via stopBlock; the test only releases them
// after observing every stopStarted, proving all three Stops were in
// flight simultaneously.
func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
a.stopBlock = make(chan struct{})
pb := newFakeProcess("b")
pb.markReady()
pb.stopBlock = make(chan struct{})
pc := newFakeProcess("c")
pc.markReady()
pc.stopBlock = make(chan struct{})
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, &stubPlanner{})
unloadDone := make(chan struct{})
go func() {
b.Unload(time.Second, "a", "b", "c")
close(unloadDone)
}()
// All three Stop calls must start before any of them are allowed to
// complete. If Unload was serial, only one stopStarted would fire
// until we released its stopBlock, and this would deadlock.
for _, p := range []*fakeProcess{a, pb, pc} {
select {
case <-p.stopStarted:
case <-time.After(2 * time.Second):
t.Fatalf("Stop on %s never started — Unload is not parallel", p.id)
}
}
// Release them; Unload should now return.
close(a.stopBlock)
close(pb.stopBlock)
close(pc.stopBlock)
select {
case <-unloadDone:
case <-time.After(2 * time.Second):
t.Fatal("Unload did not return after stops released")
}
for _, p := range []*fakeProcess{a, pb, pc} {
if p.State() != process.StateStopped {
t.Errorf("%s state=%q want stopped", p.id, p.State())
}
if got := p.stopCalls.Load(); got != 1 {
t.Errorf("%s stopCalls=%d want 1", p.id, got)
}
}
}
// TestBaseRouter_Unload_ReleasesActiveSwapWaiters verifies that Unload
// rejoins router state: a request whose swap to the unloaded model is
// still in progress receives an error, instead of being abandoned
// against a process that's about to vanish.
func TestBaseRouter_Unload_ReleasesActiveSwapWaiters(t *testing.T) {
a := newFakeProcess("a")
// autoReady=false: the swap parks on WaitReady so we can interrupt
// it with Unload before it completes.
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
done := make(chan struct{})
go func() {
b.ServeHTTP(w, newRequest("a"))
close(done)
}()
waitProcessed(t, b.testProcessed, 1) // handlerReq absorbed; swap started
<-a.runStarted
b.Unload(time.Second, "a")
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("ServeHTTP did not return after Unload")
}
if w.Code == http.StatusOK {
t.Errorf("expected non-OK status after Unload, got %d body=%q", w.Code, w.Body.String())
}
if a.State() != process.StateStopped {
t.Errorf("a state=%q want stopped", a.State())
}
}
// TestBaseRouter_Unload_DropsQueuedRequests verifies that queued requests
// for an unloaded model receive an error rather than sitting forever in
// the queue against state the router no longer maintains.
func TestBaseRouter_Unload_DropsQueuedRequests(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
// Loading B evicts A — so a request for B while A is loading queues.
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
// r1 starts the swap to A and parks on WaitReady (autoReady=false).
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1)
<-a.runStarted
// r2 for B collides with A's in-flight swap and queues.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
// Unload B — r2 (queued, targeting B) must be released with an error.
b.Unload(time.Second, "b")
select {
case <-done2:
case <-time.After(2 * time.Second):
t.Fatal("queued B request did not return after Unload(b)")
}
if w2.Code == http.StatusOK {
t.Errorf("queued B request: expected non-OK status, got %d", w2.Code)
}
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b.runCalls=%d want 0 (B should never have been started)", got)
}
// Release r1 so the test cleans up cleanly.
a.markReady()
select {
case <-done1:
case <-time.After(2 * time.Second):
t.Fatal("r1 did not complete after a.markReady")
}
}
func TestBaseRouter_FastPath(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.serveCalls.Load(); got != 1 {
t.Errorf("serveCalls=%d want 1", got)
}
if got := a.runCalls.Load(); got != 0 {
t.Errorf("runCalls=%d want 0 (fast path should not start)", got)
}
}
func TestBaseRouter_OnDemandStart(t *testing.T) {
a := newFakeProcess("a")
a.autoReady = true
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.runCalls.Load(); got != 1 {
t.Errorf("runCalls=%d want 1", got)
}
if got := a.serveCalls.Load(); got != 1 {
t.Errorf("serveCalls=%d want 1", got)
}
}
func TestBaseRouter_ConcurrentSameModel(t *testing.T) {
a := newFakeProcess("a")
// autoReady=false so the swap parks on WaitReady until we release it.
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
const N = 5
var wg sync.WaitGroup
codes := make([]int, N)
for i := 0; i < N; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
codes[i] = w.Code
}(i)
}
waitProcessed(t, b.testProcessed, N) // all N handlerReqs absorbed by run()
<-a.runStarted // swap goroutine reached Run()
a.markReady()
wg.Wait()
for i, c := range codes {
if c != http.StatusOK {
t.Errorf("request %d: status=%d", i, c)
}
}
if got := a.runCalls.Load(); got != 1 {
t.Errorf("runCalls=%d want 1 (single swap should issue one Run)", got)
}
if got := a.serveCalls.Load(); got != N {
t.Errorf("serveCalls=%d want %d", got, N)
}
}
func TestBaseRouter_ContextCancel(t *testing.T) {
a := newFakeProcess("a")
// autoReady=false so swap parks forever until we mark ready.
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
ctx, cancel := context.WithCancel(context.Background())
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequestCtx(ctx, "a"))
close(done1)
}()
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("a"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 2) // both requests joined the active swap
<-a.runStarted
cancel()
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("cancelled ServeHTTP did not return after ctx cancel")
}
a.markReady()
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("non-cancelled ServeHTTP did not complete after swap")
}
if w2.Code != http.StatusOK {
t.Errorf("second request status=%d body=%q", w2.Code, w2.Body.String())
}
}
func TestBaseRouter_QueuedDifferentModel(t *testing.T) {
a := newFakeProcess("a")
pa := newFakeProcess("b")
// Loading b must stop a.
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pa}, planner)
// First request starts a swap to A; A's autoReady=false so it parks.
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1)
<-a.runStarted
// Second request for B should queue while A's swap is in flight.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
if got := pa.runCalls.Load(); got != 0 {
t.Errorf("b started early: runCalls=%d want 0 while A's swap is pending", got)
}
// Release A's swap. B's swap should then run.
a.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone for A → B's swap kicked off
<-pa.runStarted
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("A request did not complete")
}
pa.markReady()
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("queued B request did not complete after A's swap")
}
if w2.Code != http.StatusOK {
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (B's swap must stop A)", got)
}
}
// TestBaseRouter_QueueCollation verifies that incoming requests of the form
// a, b, c, a, b, c collapse into three swaps (one per model) and that the
// second request for each model rides the fast path — either joining the
// active swap, or being pulled out of the queue when handleSwapDone promotes
// the next model.
func TestBaseRouter_QueueCollation(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
pc := newFakeProcess("c")
// Each model evicts the other two so all swaps are mutually exclusive.
planner := &stubPlanner{evict: map[string][]string{
"a": {"b", "c"},
"b": {"a", "c"},
"c": {"a", "b"},
}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
var (
completedMu sync.Mutex
completed []string
)
record := func(id string) {
completedMu.Lock()
defer completedMu.Unlock()
completed = append(completed, id)
}
ids := []string{"a", "b", "c", "a", "b", "c"}
var wg sync.WaitGroup
for _, id := range ids {
id := id
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest(id))
if w.Code != http.StatusOK {
t.Errorf("%s: status=%d body=%q", id, w.Code, w.Body.String())
return
}
record(id)
}()
// Wait for run() to absorb this request before launching the next,
// so handlerCh receives them in launch order.
waitProcessed(t, b.testProcessed, 1)
}
// All 6 are now parked in run()'s waiters/queue. Release each swap in
// sequence, waiting deterministically for each promotion to fire.
<-a.runStarted
a.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone(a) → b swap kicked off
<-pb.runStarted
pb.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone(b) → c swap kicked off
<-pc.runStarted
pc.markReady()
wg.Wait()
if got := len(completed); got != 6 {
t.Fatalf("completed=%v want 6", completed)
}
// run() fans out responses in model-grouped order (a1,a2 → b1,b2 → c1,c2)
// but waiter goroutines may be scheduled in any order after their respond
// channel fires, so completion order isn't deterministic. Per-model counts
// (combined with the runCalls checks below) are sufficient to prove queue
// collation collapsed each pair into a single swap.
aDone, bDone, cDone := 0, 0, 0
for _, id := range completed {
switch id {
case "a":
aDone++
case "b":
bDone++
case "c":
cDone++
}
}
if aDone != 2 || bDone != 2 || cDone != 2 {
t.Errorf("per-model counts: a=%d b=%d c=%d, want 2 each (order=%v)", aDone, bDone, cDone, completed)
}
// Single swap per model — the second request for each must have ridden
// the fast path (joined active swap or joined a queued sibling), not
// triggered an extra Run.
if got := a.runCalls.Load(); got != 1 {
t.Errorf("a.runCalls=%d want 1", got)
}
if got := pb.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
if got := pc.runCalls.Load(); got != 1 {
t.Errorf("c.runCalls=%d want 1", got)
}
}
// TestBaseRouter_ConcurrentDisjointSwaps verifies that two requests with
// non-conflicting evict sets are loaded in parallel: both Run() calls happen
// before either process is marked ready.
func TestBaseRouter_ConcurrentDisjointSwaps(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
// Empty evict sets for both: they can load in parallel.
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1)
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
// Both swaps must have reached Run() before either is marked ready —
// proves they ran in parallel rather than serializing.
<-a.runStarted
<-pb.runStarted
a.markReady()
pb.markReady()
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("request A did not complete")
}
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("request B did not complete")
}
if w1.Code != http.StatusOK {
t.Errorf("A status=%d body=%q", w1.Code, w1.Body.String())
}
if w2.Code != http.StatusOK {
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (parallel swap, no eviction)", got)
}
if got := pb.stopCalls.Load(); got != 0 {
t.Errorf("b.stopCalls=%d want 0 (parallel swap, no eviction)", got)
}
}
// TestBaseRouter_QueueDrainPromotesMultiple verifies that completing one swap
// unblocks every queued request that no longer collides — they all start in
// parallel rather than one-per-completion.
func TestBaseRouter_QueueDrainPromotesMultiple(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
pc := newFakeProcess("c")
// A's swap evicts both B and C, so B and C must queue. Once A finishes
// B and C themselves have empty evict sets, so they can start together.
planner := &stubPlanner{evict: map[string][]string{
"a": {"b", "c"},
}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1)
<-a.runStarted
// B and C arrive while A is loading. evict_b and evict_c are empty,
// but collidesWith returns true because they appear in A's evict set.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
w3 := httptest.NewRecorder()
done3 := make(chan struct{})
go func() {
b.ServeHTTP(w3, newRequest("c"))
close(done3)
}()
waitProcessed(t, b.testProcessed, 1)
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b started early: runCalls=%d", got)
}
if got := pc.runCalls.Load(); got != 0 {
t.Errorf("c started early: runCalls=%d", got)
}
// Release A. The swapDone handler should drain the queue and start
// both B and C in parallel.
a.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone(A) → drainQueue starts B and C
<-pb.runStarted
<-pc.runStarted
pb.markReady()
pc.markReady()
for i, ch := range []chan struct{}{done1, done2, done3} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
}
// TestBaseRouter_Shutdown_FailsAllInFlight verifies that shutdown returns
// the shutdown error to every waiter on every active swap AND to every
// queued request.
func TestBaseRouter_Shutdown_FailsAllInFlight(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
pc := newFakeProcess("c")
// a and b load in parallel (empty evicts). c collides with both.
planner := &stubPlanner{evict: map[string][]string{
"c": {"a", "b"},
}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
const waitersPer = 2
var wg sync.WaitGroup
codes := make([]int, 0, 2*waitersPer+1)
var codesMu sync.Mutex
record := func(code int) {
codesMu.Lock()
codes = append(codes, code)
codesMu.Unlock()
}
launch := func(model string) {
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest(model))
record(w.Code)
}()
}
// Active swaps for a and b, each with 2 waiters.
for i := 0; i < waitersPer; i++ {
launch("a")
waitProcessed(t, b.testProcessed, 1)
}
for i := 0; i < waitersPer; i++ {
launch("b")
waitProcessed(t, b.testProcessed, 1)
}
// c collides with both → queues.
launch("c")
waitProcessed(t, b.testProcessed, 1)
<-a.runStarted
<-pb.runStarted
if err := b.Shutdown(time.Second); err != nil {
t.Fatalf("Shutdown: %v", err)
}
wg.Wait()
codesMu.Lock()
defer codesMu.Unlock()
if len(codes) != 2*waitersPer+1 {
t.Fatalf("got %d responses, want %d", len(codes), 2*waitersPer+1)
}
for i, c := range codes {
if c == http.StatusOK {
t.Errorf("response %d: status=%d, want non-200 (shutdown)", i, c)
}
}
}
// TestBaseRouter_NoSwapWhileServing verifies that an already-loaded model
// is not stopped to satisfy another model's swap while it is still handling
// a request.
//
// Sequence:
// 1. r1 (A) — A loads; ServeHTTP enters and is pinned via serveBlock.
// 2. r2 (B, planner: B evicts A) — must NOT cause A.Stop while r1 is live.
// 3. r3 (A) — arrives next; the existing code queues it because B's swap
// intent collides with A.
// 4. r1 released — A finishes r1, then r3 is served by A.
// 5. B's swap then proceeds; r2 is served by B.
//
// fakeProcess.stoppedWhileServing flips true if Stop is ever called while
// a ServeHTTP is in flight — a direct, race-free signal of the violation.
func TestBaseRouter_NoSwapWhileServing(t *testing.T) {
a := newFakeProcess("a")
// autoReady left false: we markReady manually after observing runStarted,
// so autoReady's setState(Ready) cannot race with a later Stop and leave
// A in Ready, masking the bug.
a.serveBlock = make(chan struct{})
pb := newFakeProcess("b")
// Same reasoning for B: park its swap on WaitReady until we choose.
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
// r1 — load A and enter its ServeHTTP (which blocks on serveBlock).
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1) // handlerReq for r1
<-a.runStarted
a.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone for A
<-a.serveStarted
// r2 — would evict A. A must not be stopped while r1 is in flight.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
// r3 — another request for A, arrives behind r2 and queues because
// B's swap intent (which evicts A) is recorded as active.
w3 := httptest.NewRecorder()
done3 := make(chan struct{})
go func() {
b.ServeHTTP(w3, newRequest("a"))
close(done3)
}()
waitProcessed(t, b.testProcessed, 1)
// Release r1 (and r3 if it is fast-pathed onto the still-loaded A).
// The router must hold off B's swap until A has drained.
close(a.serveBlock)
select {
case <-done1:
case <-time.After(2 * time.Second):
t.Fatal("r1 did not complete after serveBlock release")
}
// Wait for B.Run before marking it ready: markReady before Run would
// skip the Run path entirely and leave pb.runCalls at 0. In a correct
// implementation B's swap only starts after A has drained; in the
// current implementation it has already started — either way runStarted
// fires.
<-pb.runStarted
pb.markReady()
select {
case <-done2:
case <-time.After(2 * time.Second):
t.Fatal("r2 did not complete after B marked ready")
}
select {
case <-done3:
case <-time.After(2 * time.Second):
t.Fatal("r3 did not complete")
}
if w1.Code != http.StatusOK || w2.Code != http.StatusOK || w3.Code != http.StatusOK {
t.Fatalf("statuses: w1=%d w2=%d w3=%d", w1.Code, w2.Code, w3.Code)
}
if w1.Body.String() != "ok:a" {
t.Errorf("r1 body=%q want ok:a", w1.Body.String())
}
if w3.Body.String() != "ok:a" {
t.Errorf("r3 body=%q want ok:a (r3 must be served by A)", w3.Body.String())
}
if w2.Body.String() != "ok:b" {
t.Errorf("r2 body=%q want ok:b", w2.Body.String())
}
if a.stoppedWhileServing.Load() {
t.Errorf("A.Stop was called while A was still handling a request — the router swapped out a busy process")
}
}
func TestBaseRouter_ModelNotFound(t *testing.T) {
a := newFakeProcess("a")
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("unknown"))
if w.Code != http.StatusNotFound {
t.Errorf("status=%d want %d body=%q", w.Code, http.StatusNotFound, w.Body.String())
}
}
func TestBaseRouter_Shutdown_StopsAllProcesses(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
pb := newFakeProcess("b")
pb.markReady()
go pb.Run(0)
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
if err := b.Shutdown(time.Second); err != nil {
t.Fatalf("Shutdown: %v", err)
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := pb.stopCalls.Load(); got != 1 {
t.Errorf("b.stopCalls=%d want 1", got)
}
// Subsequent ServeHTTP should report 5xx.
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
if w.Code != http.StatusInternalServerError && w.Code != http.StatusServiceUnavailable {
t.Errorf("post-shutdown status=%d want 5xx body=%q", w.Code, w.Body.String())
}
// Second Shutdown should report already in progress.
if err := b.Shutdown(0); err == nil {
t.Errorf("second Shutdown returned nil, want error")
}
}
+110
View File
@@ -0,0 +1,110 @@
package router
import (
"fmt"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type Group struct {
*baseRouter
}
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
modelToGroup := make(map[string]string)
for gid, gcfg := range conf.Groups {
for _, mid := range gcfg.Members {
if existing, dup := modelToGroup[mid]; dup {
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
}
modelToGroup[mid] = gid
}
}
planner := &groupPlanner{
config: conf,
modelToGroup: modelToGroup,
}
processes := make(map[string]process.Process, len(modelToGroup))
base := newBaseRouter("group", conf, processes, planner, proxylog)
planner.processes = processes
for mid := range modelToGroup {
modelCfg, _, ok := conf.FindConfig(mid)
if !ok {
base.shutdownFn()
return nil, fmt.Errorf("no model config for %q", mid)
}
procLog := logmon.NewWriter(upstreamlog)
p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog)
if err != nil {
base.shutdownFn()
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
}
processes[mid] = p
}
g := &Group{baseRouter: base}
go base.run()
return g, nil
}
// groupPlanner decides evictions from static group configuration.
//
// Same-group siblings are stopped when the group has swap=true. Cross-group
// members are stopped only when the target's group is exclusive; loading a
// model from a non-exclusive group leaves running exclusive groups alone,
// matching the gotcha in the original ProcessGroup behaviour.
type groupPlanner struct {
config config.Config
modelToGroup map[string]string
processes map[string]process.Process
}
func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string {
tg := p.modelToGroup[target]
tgCfg := p.config.Groups[tg]
seen := make(map[string]struct{})
var result []string
consider := func(mID string) {
if mID == target {
return
}
if _, dup := seen[mID]; dup {
return
}
og := p.modelToGroup[mID]
switch {
case og == tg && tgCfg.Swap:
seen[mID] = struct{}{}
result = append(result, mID)
// the previous ProcessGroup behaviour did not unload exclusive groups
// when loading a non-exclusive model. This maintains that gotcha
// for backwards compatibility. The newer swap matrix approach does not
// have this issue.
case og != tg && tgCfg.Exclusive:
if ogCfg := p.config.Groups[og]; !ogCfg.Persistent {
seen[mID] = struct{}{}
result = append(result, mID)
}
}
}
for mID, proc := range p.processes {
st := proc.State()
if st == process.StateStopped || st == process.StateShutdown {
continue
}
consider(mID)
}
for _, mID := range alsoRunning {
consider(mID)
}
return result
}
func (p *groupPlanner) OnSwapStart(target string) {}
+331
View File
@@ -0,0 +1,331 @@
package router
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// newTestGroup builds a Group directly from the supplied processes and config,
// bypassing NewGroup's call to process.New.
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
t.Helper()
modelToGroup := make(map[string]string)
for gid, gcfg := range conf.Groups {
for _, mid := range gcfg.Members {
modelToGroup[mid] = gid
}
}
planner := &groupPlanner{
config: conf,
modelToGroup: modelToGroup,
processes: processes,
}
base := newBaseRouter("group", conf, processes, planner, logmon.NewWriter(io.Discard))
base.testProcessed = make(chan struct{}, 64)
g := &Group{baseRouter: base}
go base.run()
t.Cleanup(func() {
if !g.shuttingDown.Load() {
_ = g.Shutdown(time.Second)
}
})
return g
}
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
conf := config.Config{
Groups: map[string]config.GroupConfig{
"g1": {Swap: true, Members: []string{"a"}},
"g2": {Swap: true, Members: []string{"a"}},
},
Models: map[string]config.ModelConfig{
"a": {},
},
}
log := logmon.NewWriter(io.Discard)
if _, err := NewGroup(conf, log, log); err == nil {
t.Fatalf("expected error for duplicate membership")
}
}
func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0) // park a Run goroutine so Stop has something to release
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
if got := b.serveCalls.Load(); got != 1 {
t.Errorf("b.serveCalls=%d want 1", got)
}
}
func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (swap=false should not stop siblings)", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
func TestGroup_CrossGroupExclusive(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (cross-group exclusive must stop)", got)
}
}
// TestGroup_CrossGroupNonExclusiveParallel verifies that two requests for
// models in distinct non-exclusive groups load in parallel rather than
// serializing through the router's run loop.
func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
g.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, g.testProcessed, 1)
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
g.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, g.testProcessed, 1)
// Both groups load concurrently — both must reach Run() before either is
// marked ready. If the router still serialised, only one would proceed.
<-a.runStarted
<-pb.runStarted
a.markReady()
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
}
if got := pb.stopCalls.Load(); got != 0 {
t.Errorf("b.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
}
}
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
// (Swap=true) serialise even when both arrive while neither has reached
// StateStarting yet — the alsoRunning hint to the planner closes that race.
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
g.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, g.testProcessed, 1)
// Request B arrives before A transitions to StateStarting in the process
// state machine. Without the alsoRunning hint, the planner would not see
// A as running, and B would start in parallel, violating Swap=true.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
g.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, g.testProcessed, 1)
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b started in parallel: runCalls=%d want 0", got)
}
<-a.runStarted
a.markReady()
waitProcessed(t, g.testProcessed, 1) // swapDone(a) → b promoted
<-pb.runStarted
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
}
}
// TestGroup_PersistentNotEvicted verifies that a group with persistent=true
// is never evicted when another exclusive group starts loading. The running
// model in the persistent group stays alive alongside the new one.
func TestGroup_PersistentNotEvicted(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (persistent group must not be evicted)", got)
}
if a.State() != process.StateStarting && a.State() != process.StateReady {
t.Errorf("a state=%s want still running", a.State())
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestGroup_NonExclusiveDoesNotUnloadExclusive pins a backwards-compatible
// gotcha from the original ProcessGroup: when a model in a non-exclusive group
// is loaded, any running exclusive group keeps running. The two coexist.
func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (non-exclusive target must not unload exclusive group)", got)
}
if a.State() != process.StateStarting && a.State() != process.StateReady {
t.Errorf("a state=%s want still running", a.State())
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
+205
View File
@@ -0,0 +1,205 @@
package router
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// fakeProcess is an in-memory implementation of process.Process used to drive
// the routers through their state machine without spawning real upstreams.
type fakeProcess struct {
id string
mu sync.Mutex
state process.ProcessState
readyCh chan struct{}
stopCh chan struct{}
runStarted chan struct{} // closed on the first Run call
stopStarted chan struct{} // closed on the first Stop call
autoReady bool
// serveBlock, when non-nil, makes ServeHTTP receive from it before
// writing its response. Tests use this to hold a request in-flight.
// Closing the channel releases every blocked ServeHTTP caller.
serveBlock chan struct{}
// serveStarted is closed on the first ServeHTTP entry, letting tests
// wait deterministically for the handler to begin executing.
serveStarted chan struct{}
// stopBlock, when non-nil, makes Stop receive from it (after signalling
// stopStarted) before completing. Tests use this to prove that several
// Stop calls can be in flight simultaneously.
stopBlock chan struct{}
runCalls atomic.Int32
stopCalls atomic.Int32
serveCalls atomic.Int32
// inFlightServe counts ServeHTTP calls currently inside the handler.
// stoppedWhileServing flips true if Stop is ever called while that
// counter is non-zero — a direct, race-free observation of the
// "swap mid-request" anti-property.
inFlightServe atomic.Int32
stoppedWhileServing atomic.Bool
}
func newFakeProcess(id string) *fakeProcess {
return &fakeProcess{
id: id,
state: process.StateStopped,
readyCh: make(chan struct{}),
stopCh: make(chan struct{}),
runStarted: make(chan struct{}),
stopStarted: make(chan struct{}),
serveStarted: make(chan struct{}),
}
}
func (f *fakeProcess) setState(s process.ProcessState) {
f.mu.Lock()
defer f.mu.Unlock()
f.state = s
if s == process.StateReady {
select {
case <-f.readyCh:
default:
close(f.readyCh)
}
}
}
func (f *fakeProcess) State() process.ProcessState {
f.mu.Lock()
defer f.mu.Unlock()
return f.state
}
func (f *fakeProcess) markReady() { f.setState(process.StateReady) }
func (f *fakeProcess) Run(_ time.Duration) error {
f.runCalls.Add(1)
f.mu.Lock()
if f.state != process.StateStopped {
s := f.state
f.mu.Unlock()
return fmt.Errorf("fakeProcess %s: Run called while %s", f.id, s)
}
f.state = process.StateStarting
sc := f.stopCh
select {
case <-f.runStarted:
default:
close(f.runStarted)
}
f.mu.Unlock()
if f.autoReady {
f.setState(process.StateReady)
}
<-sc
return nil
}
func (f *fakeProcess) Stop(_ time.Duration) error {
f.stopCalls.Add(1)
if f.inFlightServe.Load() > 0 {
f.stoppedWhileServing.Store(true)
}
f.mu.Lock()
select {
case <-f.stopStarted:
default:
close(f.stopStarted)
}
f.mu.Unlock()
// Test hook: hold Stop here so the test can prove multiple Stops are
// in flight at the same time before any of them complete.
if f.stopBlock != nil {
<-f.stopBlock
}
f.mu.Lock()
defer f.mu.Unlock()
if f.state == process.StateStopped {
return nil
}
f.state = process.StateStopped
select {
case <-f.stopCh:
default:
close(f.stopCh)
}
return nil
}
func (f *fakeProcess) WaitReady(ctx context.Context) error {
f.mu.Lock()
if f.state == process.StateReady {
f.mu.Unlock()
return nil
}
rc := f.readyCh
f.mu.Unlock()
select {
case <-rc:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (f *fakeProcess) Logger() *logmon.Monitor { return logmon.NewWriter(io.Discard) }
func (f *fakeProcess) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
f.serveCalls.Add(1)
f.inFlightServe.Add(1)
defer f.inFlightServe.Add(-1)
f.mu.Lock()
select {
case <-f.serveStarted:
default:
close(f.serveStarted)
}
f.mu.Unlock()
if f.serveBlock != nil {
<-f.serveBlock
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "ok:%s", f.id)
}
// waitProcessed drains n events from ch, fataling on timeout. One event fires
// per handlerReq or swapDone fully absorbed by run().
func waitProcessed(t *testing.T, ch chan struct{}, n int) {
t.Helper()
for i := 0; i < n; i++ {
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatalf("waitProcessed: only %d/%d events received", i, n)
}
}
}
func newRequest(model string) *http.Request {
body := fmt.Sprintf(`{"model":%q}`, model)
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/json")
return r
}
func newRequestCtx(ctx context.Context, model string) *http.Request {
return newRequest(model).WithContext(ctx)
}
+249
View File
@@ -0,0 +1,249 @@
package router
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strings"
"sync"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
var loadingPaths = []string{
"/v1/chat/completions",
}
func isLoadingPath(path string) bool {
for _, p := range loadingPaths {
if strings.HasPrefix(path, p) {
return true
}
}
return false
}
type loadingWriter struct {
hasWritten bool
writer http.ResponseWriter
req *http.Request
ctx context.Context
logger *logmon.Monitor
modelName string
startTime time.Time
pendingMu sync.Mutex
pendingUpdate string
// closed by start when the goroutine finishes (after cleanup messages)
done chan struct{}
// test-only: closed when start enters its loop
loopStarted chan struct{}
// test-only: override the 1s tick interval
tickDuration time.Duration
// test-only: override character streaming speed (0 = no delay)
charPerSecond float64
}
func newLoadingWriter(logger *logmon.Monitor, modelName string, w http.ResponseWriter, req *http.Request) *loadingWriter {
s := &loadingWriter{
writer: w,
req: req,
ctx: req.Context(),
logger: logger,
modelName: modelName,
startTime: time.Now(),
tickDuration: 750 * time.Millisecond,
charPerSecond: 75,
}
s.Header().Set("Content-Type", "text/event-stream")
s.Header().Set("Cache-Control", "no-cache")
s.Header().Set("Connection", "keep-alive")
s.WriteHeader(http.StatusOK)
s.sendLine("━━━━━")
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", modelName))
return s
}
func (s *loadingWriter) setUpdate(msg string) {
s.pendingMu.Lock()
s.pendingUpdate = msg
s.pendingMu.Unlock()
}
func (s *loadingWriter) start(ctx context.Context) {
s.done = make(chan struct{})
defer close(s.done)
defer func() {
// Skip cleanup writes if the client disconnected — the connection
// is being torn down and flushing against it will panic.
if s.ctx.Err() != nil {
return
}
duration := time.Since(s.startTime)
s.sendData("\n")
s.sendLine(fmt.Sprintf("Done! (%.2fs)", duration.Seconds()))
s.sendLine("━━━━━")
s.sendLine(" ")
}()
remarks := make([]string, len(loadingRemarks))
copy(remarks, loadingRemarks)
rand.Shuffle(len(remarks), func(i, j int) {
remarks[i], remarks[j] = remarks[j], remarks[i]
})
ri := 0
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
lastRemarkTime := time.Time{}
ticker := time.NewTicker(s.tickDuration)
defer ticker.Stop()
if s.loopStarted != nil {
close(s.loopStarted)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.pendingMu.Lock()
update := s.pendingUpdate
s.pendingUpdate = ""
s.pendingMu.Unlock()
if update != "" {
s.sendData("\n")
s.sendInline(update)
s.sendData(" ")
lastRemarkTime = time.Now()
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else if time.Since(lastRemarkTime) >= nextRemarkIn {
remark := remarks[ri%len(remarks)]
ri++
s.sendData("\n")
s.sendInline(remark)
s.sendData(" ")
lastRemarkTime = time.Now()
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else {
s.sendData(".")
}
}
}
}
func (s *loadingWriter) waitForCompletion(timeout time.Duration) bool {
if s.done == nil {
return true
}
select {
case <-s.done:
return true
case <-time.After(timeout):
return false
}
}
func (s *loadingWriter) sendInline(text string) {
chunkSize := 10
if s.charPerSecond > 0 {
chunkSize = max(3, int(s.charPerSecond)/15)
}
runes := []rune(text)
for i := 0; i < len(runes); {
select {
case <-s.ctx.Done():
return
default:
}
end := i + chunkSize
if end > len(runes) {
end = len(runes)
}
chunk := string(runes[i:end])
s.sendData(chunk)
i = end
if i < len(runes) && s.charPerSecond > 0 {
time.Sleep(time.Duration(float64(time.Second) * float64(len(chunk)) / s.charPerSecond))
}
}
}
func (s *loadingWriter) sendLine(line string) {
if line == "" {
s.sendData("\n")
return
}
s.sendInline(line)
s.sendData("\n")
}
func (s *loadingWriter) sendData(data string) {
type Delta struct {
ReasoningContent string `json:"reasoning_content"`
}
type Choice struct {
Delta Delta `json:"delta"`
}
type SSEMessage struct {
Choices []Choice `json:"choices"`
}
msg := SSEMessage{
Choices: []Choice{
{
Delta: Delta{
ReasoningContent: data,
},
},
},
}
jsonData, err := json.Marshal(msg)
if err != nil {
s.logger.Errorf("<%s> Failed to marshal SSE message: %v", s.modelName, err)
return
}
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
if err != nil {
s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err)
return
}
s.Flush()
}
func (s *loadingWriter) Header() http.Header {
return s.writer.Header()
}
func (s *loadingWriter) Write(data []byte) (int, error) {
return s.writer.Write(data)
}
func (s *loadingWriter) WriteHeader(statusCode int) {
if s.hasWritten {
return
}
s.hasWritten = true
s.writer.WriteHeader(statusCode)
s.Flush()
}
func (s *loadingWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
+133
View File
@@ -0,0 +1,133 @@
package router
var loadingRemarks = []string{
"Still faster than your last standup meeting",
"Reticulating splines",
"Waking up the hamsters",
"Teaching the model manners",
"Convincing the GPU to participate",
"Loading weights (they're heavy)",
"Please enjoy this elevator music in your head",
"Pretending to be productive",
"Reading the entire internet, page by page",
"Staring at the abyss, the abyss is buffering",
"Applying layer after layer of disembodied cognition",
"Remembering everything it forgot during quantization",
"Counting to 405 billion, one parameter at a time",
"Summoning the stochastic parroting",
"Hold on, the GPU is questioning its existence",
"Deciding which facts to hallucinate today",
"Untangling the transformer spaghetti",
"Warming up the token soup",
"Your prompt is in a queue, behind 7 billion other thoughts",
"Running `sudo apt-get install intelligence`",
"Defragmenting the latent space",
"Polishing each matrix multiplication by hand",
"Whispering sweet nothings to the attention heads",
"Aligning with human values, one reluctant epoch at a time",
"The model is thinking about what it's about to think about",
"Loading... and by loading we mean making you wait",
"Spinning up the cloud GPU, please be patient while we burn your credits",
"Applying duct tape to the context window",
"Bribing the GPU scheduler for a timeslice",
"Would you like to hear a fun fact while we load? Too bad.",
"Hot swapping your sanity for an LLM",
"Compressing optimism into FP16",
"Ignoring 90% of the attention to save you 50% of the time",
"Counting the exact same thing three times just to be sure",
"Sorry, the inference you have reached is not in service",
"Rotating the positional encodings counterclockwise for good luck",
"Your call is very important to us. Please continue to hold.",
"Unpacking the blobs. All 300GB of them.",
"Initializing the thing that initializes the other thing",
"Converting electricity into existential dread",
"Flattening the curve... wait, the tensor. Flattening the tensor.",
"Fetching the fetch of a fetch, callback hell edition",
"The GPU is at 100%. The fan is now a helicopter.",
"Baking the weights at 350° for a golden-brown inference",
"Recalibrating the confidence of things it's still wrong about",
"Have you tried turning it off and on again? No? Good, wait here.",
"Simulating deep thought by pausing dramatically",
"Loading the model that knows more than you but still can't count r's in 'strawberry'",
"Convincing CUDA to cooperate. This may take a while.",
"VRAM: 23.9GB used of 24GB. Living on the edge.",
"Processing your request with the urgency of a DMV employee",
"This model was trained on the entire internet, including that embarrassing blog you wrote in 2008",
"Dispatching tokens through a series of increasingly confused matrix multiplies",
"Gently lowering your expectations",
"Applying softmax to our feelings about this load time",
"Autoregressively generating disappointment, one token at a time",
"The magic is happening. Somewhere. Probably.",
"Synchronizing the parallel processes that run in parallel but really don't",
"Calculating the meaning of life. Spoiler: it's 42, but we're double-checking.",
"Loading... just like it said 30 seconds ago. And will say 30 seconds from now.",
"Pre-warming the cache so the first query is only slightly slower than the rest",
"Have you considered that maybe your question wasn't worth all this compute?",
"Downloading more RAM (no, really, we're mmap-ing the weights)",
"Translating your prompt into math it barely understands",
"Estimating your time remaining with 0% accuracy",
"Buffering enthusiasm",
"Model is loading. Go make some coffee. Or a three-course meal.",
"Tokenizing the dictionary, filing a grievance on behalf of 'antidisestablishmentarianism'",
"Polling for readiness in a loop that would make your CS professor weep",
"Performing percussive maintenance on the attention mechanism",
"This loading screen is singlehandedly reversing climate progress",
"Decompressing the hopes and dreams of thousands of underpaid labelers",
"Filling the key-value cache with the ghost of prompts past",
"Currently at step 3 of 9,742 of loading. We'll get there. Eventually.",
"If you stare at the spinner, it spins slower. It's science.",
"Multiplying matricies with the enthusiasm of a teenager doing chores",
"Applying `torch.nap()` until the model feels refreshed",
"Reacquainting the model with the concept of 'facts' it forgot during fine-tuning",
"Sorry for the wait. No, wait, we're not actually sorry.",
"Your GPU is now a space heater with a side hustle in linear algebra",
"Allocating memory like a billionaire allocates tax avoidance strategies",
"The model saw \"As an AI language model\" and won't stop saying it now",
"Installing dependencies you didn't know existed and will never use again",
"Re-reading 'Attention Is All You Need' for the 400th time",
"Convincing the embedding layer that context is overrated",
"Manually untangling the residual connections with a tiny comb",
"On hold with the cloud provider trying to explain why 8 H100s isn't enough",
"Adjusting temperatures: model is 0.7, server room is 104°F",
"Please hold while we justify this electricity bill to accounting",
"Stacking decoder blocks like a Jenga tower at a LAN party",
"Compensating for your lack of patience with our lack of speed",
"This is a loading screen comment. Loading screens have comments now. Welcome to the future.",
"Processing the entire works of Shakespeare backwards just in case",
"The model is loading slower than your last `npm install`",
"Rehearsing plausible-sounding explanations for why it got everything wrong",
"Populating the context with filler while you wait for actual content",
"Optimizing for BLEU score, which definitely correlates with making you laugh",
"Generating an embedding for each and every letter of the alphabet, individually",
"Coming soon: llama-swap v2 with actual performance improvements. Probably.",
"Loading a model larger than your attention span",
"Performing a seance to invoke the spirit of Geoff Hinton",
"Did you know loading screens were invented to prevent users from smashing their monitors? Now you do.",
"Converting all the internet's bad opinions into a surprisingly useful autocomplete",
"Laying down each layer with the care of a Michelin-starred pastry chef",
"Checking if the model still thinks birds are government drones. Yep.",
"Activating the neurons responsible for 'I cannot assist with that request'",
"This model was trained on the same internet that brought you Rickrolling. You're welcome.",
"Realigning the alignment so it aligns with the previous alignment",
"Running `nvidia-smi` and sighing heavily",
"If you close your eyes, the loading bar moves faster. Proven by science.",
"EULA said 'by using this software you agree to wait forever' and you clicked Accept",
"Zipping the GPUs to make them go faster",
"Padding the context window with existential padding",
"We could have used a smaller model but someone wanted 'quality'",
"Disentangling the latent space into something resembling coherence",
"Slow is smooth, smooth is fast, but this is just slow",
"Memory-mapping like it's a AAA title from 2012",
"Your patience has been tokenized and added to the training set. Thank you for your contribution.",
"Loading is CPU-bound and your CPU is busy regretting its life choices",
"Exploring the high-dimensional manifold of ways to say 'just a moment'",
"The model is experiencing a brief but intense moment of imposter syndrome",
"Initializing 7B parameters by rolling 7B 16-sided dice",
"Panic! at the disk I/O",
"Intelligence is loading... your definition of intelligence may vary",
"This model was distilled. Unlike your patience, which is evaporating.",
"Unzipping the model. It's a .gguf file, not a metaphor.",
"Running inference on the concept of 'soon' to estimate remaining time",
"Loading with all the speed of a government-funded IT project",
"A blank terminal is a terrible thing to waste. Here's a loading message instead.",
}
+328
View File
@@ -0,0 +1,328 @@
package router
import (
"bufio"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
func TestLoadingWriter_SSEHeadersAndInitialMessage(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
if ct := lw.Header().Get("Content-Type"); ct != "text/event-stream" {
t.Errorf("Content-Type: want text/event-stream, got %q", ct)
}
if cc := lw.Header().Get("Cache-Control"); cc != "no-cache" {
t.Errorf("Cache-Control: want no-cache, got %q", cc)
}
if conn := lw.Header().Get("Connection"); conn != "keep-alive" {
t.Errorf("Connection: want keep-alive, got %q", conn)
}
body := w.Body.String()
if !strings.HasPrefix(body, "data: ") {
t.Errorf("expected SSE data: prefix, got: %s", body)
}
content := extractStreamedContent(body)
if !strings.Contains(content, "━━━━━\n") {
t.Errorf("missing separator in streamed content: %q", content)
}
if !strings.Contains(content, "llama-swap loading model: test-model\n") {
t.Errorf("missing initial message in streamed content: %q", content)
}
}
func TestLoadingWriter_WriteHeaderOnce(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.WriteHeader(http.StatusCreated)
if w.Code != http.StatusOK {
t.Errorf("first WriteHeader: want %d, got %d", http.StatusOK, w.Code)
}
}
func TestLoadingWriter_WritePassthrough(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.Write([]byte("hello"))
lw.Flush()
body := w.Body.String()
if !strings.Contains(body, "hello") {
t.Errorf("Write passthrough failed, body: %s", body)
}
}
func TestLoadingWriter_StartStopsOnCancel(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go lw.start(ctx)
<-lw.loopStarted
cancel()
if !lw.waitForCompletion(time.Second) {
t.Fatal("waitForCompletion timed out")
}
body := w.Body.String()
if !strings.Contains(body, "Done!") {
t.Errorf("expected Done! message, body: %s", body)
}
}
func TestLoadingWriter_StartShowsSetUpdate(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.charPerSecond = 0
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go lw.start(ctx)
<-lw.loopStarted
lw.setUpdate("custom status message")
time.Sleep(50 * time.Millisecond)
cancel()
if !lw.waitForCompletion(time.Second) {
t.Fatal("waitForCompletion timed out")
}
body := w.Body.String()
content := extractStreamedContent(body)
if !strings.Contains(content, "custom status message") {
t.Errorf("expected setUpdate message in output, got: %q", content)
}
}
func TestLoadingWriter_SendDataFormat(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.sendData("hello world")
body := w.Body.String()
if !strings.Contains(body, `"reasoning_content":"hello world"`) {
t.Errorf("expected reasoning_content in SSE data, body: %s", body)
}
if !strings.HasPrefix(body, "data: ") {
t.Errorf("expected data: prefix, got: %s", body)
}
}
func TestLoadingWriter_SendLine(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.charPerSecond = 0
// Capture only the content from this sendLine call
before := w.Body.Len()
lw.sendLine("line content")
after := w.Body.Len()
chunkBody := w.Body.String()[before:after]
content := extractStreamedContent(chunkBody)
if content != "line content\n" {
t.Errorf("expected complete streamed line, got: %q", content)
}
}
func TestLoadingWriter_FlushesPeriodicallyDuringStatusUpdates(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.charPerSecond = 0
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
lw.start(ctx)
close(done)
}()
<-lw.loopStarted
time.Sleep(50 * time.Millisecond)
cancel()
<-done
body := w.Body.String()
lines := countSSEMessages(body)
if lines < 2 {
t.Errorf("expected multiple SSE messages from periodic updates, got %d", lines)
}
}
func TestLoadingWriter_ReqStored(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
if lw.req != req {
t.Fatal("req not stored")
}
}
func TestIsLoadingPath(t *testing.T) {
tests := []struct {
path string
want bool
}{
{"/v1/chat/completions", true},
{"/v1/chat/completions/extra", true},
{"/v1/completions", false},
{"/v1/embeddings", false},
{"/health", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := isLoadingPath(tt.path); got != tt.want {
t.Errorf("isLoadingPath(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
func TestExtractContext_Streaming_GET(t *testing.T) {
tests := []struct {
name string
query string
wantStreaming bool
}{
{"streaming true", "model=llama3&stream=true", true},
{"streaming false", "model=llama3&stream=false", false},
{"no stream param", "model=llama3", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
got, err := ExtractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if got.Streaming != tt.wantStreaming {
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
}
})
}
}
func TestExtractContext_Streaming_JSON(t *testing.T) {
tests := []struct {
name string
body string
wantStreaming bool
}{
{"streaming true", `{"model":"llama3","stream":true}`, true},
{"streaming false", `{"model":"llama3","stream":false}`, false},
{"no stream param", `{"model":"llama3"}`, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
got, err := ExtractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if got.Streaming != tt.wantStreaming {
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
}
})
}
}
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
got, err := ExtractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if !got.Streaming {
t.Error("Streaming should be true")
}
}
func countSSEMessages(s string) int {
scanner := bufio.NewScanner(strings.NewReader(s))
count := 0
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
count++
}
}
return count
}
func extractStreamedContent(body string) string {
var result strings.Builder
scanner := bufio.NewScanner(strings.NewReader(body))
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
jsonData := strings.TrimPrefix(line, "data: ")
var msg struct {
Choices []struct {
Delta struct {
ReasoningContent string `json:"reasoning_content"`
} `json:"delta"`
} `json:"choices"`
}
if err := json.Unmarshal([]byte(jsonData), &msg); err != nil {
continue
}
if len(msg.Choices) > 0 {
result.WriteString(msg.Choices[0].Delta.ReasoningContent)
}
}
return result.String()
}
+100
View File
@@ -0,0 +1,100 @@
package router
import (
"fmt"
"sort"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type Matrix struct {
*baseRouter
}
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
if conf.Matrix == nil {
return nil, fmt.Errorf("matrix router requires a matrix configuration")
}
planner := &matrixPlanner{
solver: newMatrixSolver(conf.ExpandedSets, conf.Matrix.ResolvedEvictCosts()),
logger: proxylog,
}
// Build a process for every model in the config. Any model can run alone
// even if it is not part of a set; this mirrors proxy.NewMatrix.
processes := make(map[string]process.Process, len(conf.Models))
base := newBaseRouter("matrix", conf, processes, planner, proxylog)
planner.processes = processes
for mid, modelCfg := range conf.Models {
procLog := logmon.NewWriter(upstreamlog)
p, err := process.New(base.shutdownCtx, mid, modelCfg, procLog, proxylog)
if err != nil {
base.shutdownFn()
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
}
processes[mid] = p
}
r := &Matrix{baseRouter: base}
go base.run()
return r, nil
}
// matrixPlanner decides evictions by asking the matrix solver against the
// current running set.
type matrixPlanner struct {
solver *matrixSolver
processes map[string]process.Process
logger *logmon.Monitor
}
func (p *matrixPlanner) EvictionFor(target string, alsoRunning []string) []string {
return p.solver.Solve(target, p.runningSet(alsoRunning)).Evict
}
func (p *matrixPlanner) OnSwapStart(target string) {
running := p.runningModels()
result := p.solver.Solve(target, running)
switch {
case len(result.Evict) > 0:
p.logger.Infof("matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
target, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
case len(running) == 0:
p.logger.Infof("matrix: model=%s starting (no models running)", target)
default:
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
}
}
func (p *matrixPlanner) runningModels() []string {
return p.runningSet(nil)
}
// runningSet returns the union of live processes (State != Stopped/Shutdown)
// and any extra IDs the baseRouter has already committed to loading but which
// the process state machine has not yet reflected.
func (p *matrixPlanner) runningSet(alsoRunning []string) []string {
seen := make(map[string]struct{}, len(p.processes))
var running []string
for id, proc := range p.processes {
st := proc.State()
if st == process.StateStopped || st == process.StateShutdown {
continue
}
seen[id] = struct{}{}
running = append(running, id)
}
for _, id := range alsoRunning {
if _, dup := seen[id]; dup {
continue
}
seen[id] = struct{}{}
running = append(running, id)
}
sort.Strings(running)
return running
}
+132
View File
@@ -0,0 +1,132 @@
package router
import (
"slices"
"github.com/mostlygeek/llama-swap/internal/config"
)
// matrixSolver contains pure swap-decision logic with no Process dependencies.
// It is safe for concurrent reads after construction.
type matrixSolver struct {
expandedSets []config.ExpandedSet // all valid model combinations
evictCosts map[string]int // real model name -> eviction cost (default 1)
modelToSets map[string][]int // model name -> indices into expandedSets
}
func newMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *matrixSolver {
modelToSets := make(map[string][]int)
for i, es := range expandedSets {
for _, model := range es.Models {
modelToSets[model] = append(modelToSets[model], i)
}
}
return &matrixSolver{
expandedSets: expandedSets,
evictCosts: evictCosts,
modelToSets: modelToSets,
}
}
// solveResult describes what the solver decided.
type solveResult struct {
Evict []string // running models that must be stopped
TargetSet []string // the chosen set of models (for informational purposes)
SetName string // name of the chosen set
DSL string // original DSL expression for the chosen set
TotalCost int // total eviction cost
}
// Solve determines which models to evict when a model is requested.
//
// Algorithm:
// 1. If requestedModel is already running, no eviction needed.
// 2. Find all sets containing requestedModel.
// 3. If no sets found, the model runs alone; evict all running models.
// 4. For each candidate set, compute cost = sum of evict_costs for running
// models NOT in that set.
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
// 6. Return models to evict and the chosen set.
func (s *matrixSolver) Solve(requestedModel string, runningModels []string) solveResult {
if slices.Contains(runningModels, requestedModel) {
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
return solveResult{
TargetSet: runningModels,
SetName: setName,
DSL: dsl,
}
}
candidateIndices := s.modelToSets[requestedModel]
// Model not in any set: runs alone, evict everything.
if len(candidateIndices) == 0 {
evict := make([]string, len(runningModels))
copy(evict, runningModels)
return solveResult{
Evict: evict,
TargetSet: []string{requestedModel},
}
}
bestCost := -1
bestIdx := -1
for _, idx := range candidateIndices {
setModels := s.expandedSets[idx].Models
cost := 0
for _, running := range runningModels {
if !slices.Contains(setModels, running) {
cost += s.evictCost(running)
}
}
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
bestCost = cost
bestIdx = idx
}
}
chosen := s.expandedSets[bestIdx]
var evict []string
for _, running := range runningModels {
if !slices.Contains(chosen.Models, running) {
evict = append(evict, running)
}
}
return solveResult{
Evict: evict,
TargetSet: chosen.Models,
SetName: chosen.SetName,
DSL: chosen.DSL,
TotalCost: bestCost,
}
}
// findMatchingSet finds the expanded set that contains all running models.
// Returns the set name and DSL, or empty strings if no match.
func (s *matrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
for _, idx := range s.modelToSets[requestedModel] {
set := s.expandedSets[idx]
allInSet := true
for _, m := range runningModels {
if !slices.Contains(set.Models, m) {
allInSet = false
break
}
}
if allInSet {
return set.SetName, set.DSL
}
}
return "", ""
}
func (s *matrixSolver) evictCost(model string) int {
if cost, ok := s.evictCosts[model]; ok {
return cost
}
return 1
}
+244
View File
@@ -0,0 +1,244 @@
package router
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// newTestMatrix builds a Matrix router from supplied processes, bypassing
// NewMatrix's call to process.New.
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
t.Helper()
logger := logmon.NewWriter(io.Discard)
planner := &matrixPlanner{
solver: newMatrixSolver(expanded, evictCosts),
processes: processes,
logger: logger,
}
base := newBaseRouter("matrix", conf, processes, planner, logger)
base.testProcessed = make(chan struct{}, 64)
r := &Matrix{baseRouter: base}
go base.run()
t.Cleanup(func() {
if !r.shuttingDown.Load() {
_ = r.Shutdown(time.Second)
}
})
return r
}
func baseMatrixConfig() config.Config {
return config.Config{
HealthCheckTimeout: 5,
Matrix: &config.MatrixConfig{},
}
}
// TestMatrix_SwapEvictsConflicting verifies that loading a model triggers
// eviction of running models that are not in any shared set with it.
func TestMatrix_SwapEvictsConflicting(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0) // park a Run goroutine so Stop has something to release
b := newFakeProcess("b")
b.autoReady = true
// Two single-model sets: a and b never coexist, so loading b must evict a.
expanded := []config.ExpandedSet{
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
r.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestMatrix_CoexistInSet verifies that a model is not evicted when the target
// shares a set with it (the fast path applies if the target is already ready).
func TestMatrix_CoexistInSet(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
// Both fit in s_ab, so b's swap should not stop a.
expanded := []config.ExpandedSet{
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
r.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestMatrix_CoexistingSetParallel verifies that two models that share an
// expanded set load in parallel — the solver returns empty Evict for both,
// the collision predicate clears them, and both swaps run together.
func TestMatrix_CoexistingSetParallel(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
expanded := []config.ExpandedSet{
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
r.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, r.testProcessed, 1)
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
r.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, r.testProcessed, 1)
<-a.runStarted
<-pb.runStarted
a.markReady()
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
}
if got := pb.stopCalls.Load(); got != 0 {
t.Errorf("b.stopCalls=%d want 0 (coexists with a)", got)
}
}
// TestMatrix_IncompatibleQueues verifies that the second request for a model
// that cannot coexist with the in-flight first model queues until the first
// completes, and then evicts it. This exercises the alsoRunning hint via the
// matrix solver's union into runningSet.
func TestMatrix_IncompatibleQueues(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
expanded := []config.ExpandedSet{
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
r.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, r.testProcessed, 1)
// B arrives before A transitions to StateStarting. The solver sees A via
// alsoRunning and returns evict=[a], so collidesWith forces B to queue.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
r.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, r.testProcessed, 1)
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b started in parallel: runCalls=%d want 0", got)
}
<-a.runStarted
a.markReady()
waitProcessed(t, r.testProcessed, 1) // swapDone(a) → b promoted, evicts a
<-pb.runStarted
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
}
}
// TestMatrixSolver_TieBreakDefinitionOrder pins the solver's tie-break rule:
// when multiple candidate sets have equal eviction cost, the earlier-defined
// set wins.
func TestMatrixSolver_TieBreakDefinitionOrder(t *testing.T) {
expanded := []config.ExpandedSet{
{SetName: "first", DSL: "a & b", Models: []string{"a", "b"}},
{SetName: "second", DSL: "a & c", Models: []string{"a", "c"}},
}
s := newMatrixSolver(expanded, nil)
// No models running, request "a": both sets have cost 0 and contain a.
// Definition order: "first" wins.
result := s.Solve("a", nil)
if result.SetName != "first" {
t.Errorf("SetName=%q want %q", result.SetName, "first")
}
}
// TestMatrixSolver_EvictCostsPreferred verifies that higher evict costs steer
// the solver toward a cheaper set.
func TestMatrixSolver_EvictCostsPreferred(t *testing.T) {
// b is expensive to evict; c is cheap. Request "a" with both b and c
// running. The solver should pick the set that keeps b.
expanded := []config.ExpandedSet{
{SetName: "a_with_c", DSL: "a & c", Models: []string{"a", "c"}}, // would evict b (cost 10)
{SetName: "a_with_b", DSL: "a & b", Models: []string{"a", "b"}}, // would evict c (cost 1)
}
s := newMatrixSolver(expanded, map[string]int{"b": 10, "c": 1})
result := s.Solve("a", []string{"b", "c"})
if result.SetName != "a_with_b" {
t.Errorf("SetName=%q want %q (keep expensive b)", result.SetName, "a_with_b")
}
if len(result.Evict) != 1 || result.Evict[0] != "c" {
t.Errorf("Evict=%v want [c]", result.Evict)
}
}
+188
View File
@@ -0,0 +1,188 @@
package router
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
type peerMember struct {
peerID string
reverseProxy *httputil.ReverseProxy
apiKey string
}
type Peer struct {
cfg config.Config
logger *logmon.Monitor
peers map[string]*peerMember
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
inflight sync.WaitGroup
}
func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
peers := cfg.Peers
modelMap := make(map[string]*peerMember)
peerIDs := make([]string, 0, len(peers))
for peerID := range peers {
peerIDs = append(peerIDs, peerID)
}
sort.Strings(peerIDs)
for _, peerID := range peerIDs {
peer := peers[peerID]
peerTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
}
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
reverseProxy.Transport = peerTransport
originalDirector := reverseProxy.Director
reverseProxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = req.URL.Host
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
logger.Warnf("peer %s: proxy error: %v", peerID, err)
errMsg := fmt.Sprintf("peer proxy error: %v", err)
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
}
http.Error(w, errMsg, http.StatusBadGateway)
}
pp := &peerMember{
peerID: peerID,
reverseProxy: reverseProxy,
apiKey: peer.ApiKey,
}
for _, modelID := range peer.Models {
if _, found := modelMap[modelID]; found {
logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
continue
}
modelMap[modelID] = pp
}
}
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
return &Peer{
cfg: cfg,
logger: logger,
peers: modelMap,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
}, nil
}
func (r *Peer) Handles(model string) bool {
_, ok := r.peers[model]
return ok
}
func (r *Peer) Shutdown(timeout time.Duration) error {
if !r.shuttingDown.CompareAndSwap(false, true) {
return fmt.Errorf("shutdown already in progress")
}
if timeout == 0 {
r.shutdownFn()
r.inflight.Wait()
return nil
}
done := make(chan struct{})
go func() {
r.inflight.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-time.After(timeout):
r.shutdownFn()
r.inflight.Wait()
return fmt.Errorf("peer shutdown timed out after %v", timeout)
}
}
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.shuttingDown.Load() {
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
return
}
r.inflight.Add(1)
defer r.inflight.Done()
data, err := FetchContext(req, r.cfg)
if err != nil {
SendError(w, req, err)
return
}
pp, found := r.peers[data.ModelID]
if !found {
r.logger.Warnf("peer model not found: %s", data.ModelID)
SendError(w, req, ErrNoPeerModelFound)
return
}
r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID)
if pp.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+pp.apiKey)
req.Header.Set("x-api-key", pp.apiKey)
}
// Cancel the proxy request when the client disconnects or shutdown times out.
// AfterFunc links both parent contexts to our child without a goroutine leak.
ctx, cancel := context.WithCancel(context.Background())
stopReq := context.AfterFunc(req.Context(), cancel)
stopShutdown := context.AfterFunc(r.shutdownCtx, cancel)
req = req.WithContext(ctx)
pp.reverseProxy.ServeHTTP(w, req)
stopShutdown()
stopReq()
cancel()
}
+611
View File
@@ -0,0 +1,611 @@
package router
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
var testLogger = logmon.NewWriter(os.Stdout)
func init() {
testLogger.SetLogLevel(logmon.LevelWarn)
}
func TestNewPeer_EmptyPeers(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
if pr == nil {
t.Fatal("expected non-nil Peer")
}
if len(pr.peers) != 0 {
t.Fatalf("expected empty peers map, got %d entries", len(pr.peers))
}
}
func TestNewPeer_SinglePeer(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
ApiKey: "test-key",
Models: []string{"model-a", "model-b"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 2 {
t.Fatalf("expected 2 entries, got %d", len(pr.peers))
}
if _, ok := pr.peers["model-a"]; !ok {
t.Error("expected model-a to be mapped")
}
if _, ok := pr.peers["model-b"]; !ok {
t.Error("expected model-b to be mapped")
}
if _, ok := pr.peers["model-c"]; ok {
t.Error("expected model-c to not be mapped")
}
}
func TestNewPeer_MultiplePeers(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"model-a", "model-b"},
},
"peer2": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"model-c", "model-d"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 4 {
t.Fatalf("expected 4 entries, got %d", len(pr.peers))
}
for _, m := range []string{"model-a", "model-b", "model-c", "model-d"} {
if _, ok := pr.peers[m]; !ok {
t.Errorf("expected %s to be mapped", m)
}
}
}
func TestNewPeer_DuplicateModel(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"alpha-peer": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"duplicate-model"},
},
"beta-peer": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"duplicate-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 1 {
t.Fatalf("expected 1 entry for duplicate model, got %d", len(pr.peers))
}
if _, ok := pr.peers["duplicate-model"]; !ok {
t.Error("expected duplicate-model to be mapped")
}
}
func TestPeer_ServeHTTP_Success(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("response from peer"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
if w.Body.String() != "response from peer" {
t.Errorf("expected 'response from peer', got %q", w.Body.String())
}
}
func TestPeer_ServeHTTP_ModelNotFoundInContext(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "secret-api-key",
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if receivedAuthHeader != "Bearer secret-api-key" {
t.Errorf("expected 'Bearer secret-api-key', got %q", receivedAuthHeader)
}
}
func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "",
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if receivedAuthHeader != "" {
t.Errorf("expected no auth header, got %q", receivedAuthHeader)
}
}
func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
var receivedHost string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHost = r.Host
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if !strings.HasPrefix(receivedHost, "127.0.0.1:") {
t.Errorf("expected Host to start with '127.0.0.1:', got %q", receivedHost)
}
}
func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Header().Get("X-Accel-Buffering") != "no" {
t.Errorf("expected X-Accel-Buffering=no, got %q", w.Header().Get("X-Accel-Buffering"))
}
}
func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "shutting down") {
t.Errorf("expected 'shutting down' in body, got %q", w.Body.String())
}
}
func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
started := make(chan struct{})
released := make(chan struct{})
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-released
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
}()
<-started
shutdownDone := make(chan error, 1)
go func() {
shutdownDone <- pr.Shutdown(500 * time.Millisecond)
}()
// Shutdown should be waiting on inflight. If it finished already something is wrong.
time.Sleep(100 * time.Millisecond)
select {
case err := <-shutdownDone:
t.Errorf("shutdown completed before inflight finished: %v", err)
default:
}
close(released)
wg.Wait()
select {
case err := <-shutdownDone:
if err != nil {
t.Errorf("shutdown errored after inflight completed: %v", err)
}
case <-time.After(2 * time.Second):
t.Error("shutdown did not complete after inflight finished")
}
}
func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
started := make(chan struct{})
released := make(chan struct{})
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-released
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
}()
<-started
err = pr.Shutdown(100 * time.Millisecond)
if err == nil {
t.Error("expected timeout error from shutdown")
}
close(released)
wg.Wait()
}
func TestPeer_ShutdownMultiple(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err == nil {
t.Error("expected error on second shutdown")
}
if !strings.Contains(err.Error(), "already in progress") {
t.Errorf("expected 'already in progress', got %q", err.Error())
}
}
func TestPeer_ServeHTTP_ModelExtractedFromBody(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"extracted-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
body := strings.NewReader(`{"model":"extracted-model","prompt":"hello"}`)
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"context-model"},
},
"peer2": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"body-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
req.Header.Set("Content-Type", "application/json")
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestNewPeer_CustomTimeouts(t *testing.T) {
proxyURL, _ := url.Parse("http://localhost:8080")
peers := config.PeerDictionaryConfig{
"test-peer": config.PeerConfig{
Proxy: "http://localhost:8080",
ProxyURL: proxyURL,
Models: []string{"model1"},
Timeouts: config.TimeoutsConfig{
Connect: 45,
ResponseHeader: 300,
TLSHandshake: 15,
ExpectContinue: 2,
IdleConn: 120,
},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
member, ok := pr.peers["model1"]
if !ok {
t.Fatal("expected model1 to be mapped")
}
transport, ok := member.reverseProxy.Transport.(*http.Transport)
if !ok {
t.Fatal("expected Transport to be *http.Transport")
}
if transport.ResponseHeaderTimeout != 300*time.Second {
t.Errorf("expected ResponseHeaderTimeout=%v, got %v", 300*time.Second, transport.ResponseHeaderTimeout)
}
if transport.TLSHandshakeTimeout != 15*time.Second {
t.Errorf("expected TLSHandshakeTimeout=%v, got %v", 15*time.Second, transport.TLSHandshakeTimeout)
}
if transport.ExpectContinueTimeout != 2*time.Second {
t.Errorf("expected ExpectContinueTimeout=%v, got %v", 2*time.Second, transport.ExpectContinueTimeout)
}
if transport.IdleConnTimeout != 120*time.Second {
t.Errorf("expected IdleConnTimeout=%v, got %v", 120*time.Second, transport.IdleConnTimeout)
}
if !transport.ForceAttemptHTTP2 {
t.Error("expected ForceAttemptHTTP2 to be true")
}
}
+199
View File
@@ -0,0 +1,199 @@
package router
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
"github.com/tidwall/gjson"
)
type contextkey struct {
name string
}
type ReqContextData struct {
Model string
ModelID string
Streaming bool
SendLoadingState bool
}
var (
ErrNoModelInContext = fmt.Errorf("no model in request context")
ErrNoRouterFound = fmt.Errorf("no router found for model")
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
ErrNoLocalModelFound = fmt.Errorf("local model not found")
ContextKey = &contextkey{"context"}
)
type Router interface {
// Shutdown blocks until the router has shutdown returning nil
// when the router has shutdown successfully.
//
// timeout controls how long to wait for inflight requests to finish. After
// the timeout all inflight requests will be cancelled.
Shutdown(timeout time.Duration) error
// ServeHTTP implements the http.Handler and requests coming in will
// trigger any model swapping and routing logic.
ServeHTTP(http.ResponseWriter, *http.Request)
// Handles reports whether this router can serve requests for the given model.
Handles(model string) bool
}
// LocalRouter is a Router backed by local processes whose state can be
// inspected and which can be individually stopped. Peer routers, which only
// forward to remote hosts, do not implement it.
type LocalRouter interface {
Router
// RunningModels returns the current state of every process that is not
// stopped or shut down, keyed by model ID.
RunningModels() map[string]process.ProcessState
// Unload stops the named models, or every running model when none are
// named. It blocks until each targeted process has stopped.
Unload(timeout time.Duration, models ...string)
// ProcessLogger returns the log monitor for the named model's process.
// modelID must be a real (non-alias) config key. Returns false when the
// model is not known to this router.
ProcessLogger(modelID string) (*logmon.Monitor, bool)
}
// FetchContext will attempt to get the model id from the context then
// from the model body. If it extracts the model from the body it will
// store the model in the context for downstream handlers. An error
// will be returned when model can not be fetch from either location.
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
data, ok := ReadContext(r.Context())
if ok {
return data, nil
}
if data, err := ExtractContext(r); err == nil {
realName, _ := cfg.RealModelName(data.Model)
if realName == "" {
realName = data.Model
}
data.ModelID = realName
if mc, ok := cfg.Models[realName]; ok {
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
}
*r = *r.WithContext(SetContext(r.Context(), data))
return data, nil
}
return ReqContextData{}, ErrNoModelInContext
}
func SetContext(ctx context.Context, data ReqContextData) context.Context {
return context.WithValue(ctx, ContextKey, data)
}
func ReadContext(ctx context.Context) (ReqContextData, bool) {
data, ok := ctx.Value(ContextKey).(ReqContextData)
return data, ok
}
// ExtractContext pulls the model name from an HTTP request without consuming the
// body. For GET requests it reads the "model" query parameter. For POST
// requests it inspects Content-Type and parses JSON, multipart/form-data, or
// application/x-www-form-urlencoded bodies. The request body is always restored
// before returning so downstream handlers — including reverse proxies that
// forward raw bytes upstream — can still read it.
func ExtractContext(r *http.Request) (ReqContextData, error) {
if r.Method == http.MethodGet {
if model := r.URL.Query().Get("model"); model != "" {
return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil
}
return ReqContextData{}, fmt.Errorf("missing 'model' query parameter")
}
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
}
defer func() {
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}()
contentType := r.Header.Get("Content-Type")
if strings.Contains(contentType, "application/json") {
model := gjson.GetBytes(bodyBytes, "model").String()
if model == "" {
return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body")
}
return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil
}
// Form parsers read from r.Body, so feed them a fresh reader over the
// buffered bytes. The deferred restore above will reset r.Body again
// after parsing.
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if strings.Contains(contentType, "multipart/form-data") {
if err := r.ParseMultipartForm(32 << 20); err != nil {
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
}
} else {
if err := r.ParseForm(); err != nil {
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
}
}
if model := r.FormValue("model"); model != "" {
return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil
}
return ReqContextData{}, fmt.Errorf("missing 'model' parameter")
}
func SendError(w http.ResponseWriter, r *http.Request, err error) {
switch {
case errors.Is(err, ErrNoModelInContext):
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
case errors.Is(err, ErrNoPeerModelFound):
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
case errors.Is(err, ErrNoLocalModelFound):
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
case errors.Is(err, ErrNoRouterFound):
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
default:
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
}
}
// SendResponse detects what content type the client prefers and returns an error response in that format.
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
// Check Accept header for preferred response format
acceptHeader := r.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/plain") {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(status)
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
return
}
if strings.Contains(acceptHeader, "text/html") {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(status)
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, message)))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message)))
}
+275
View File
@@ -0,0 +1,275 @@
package router
import (
"bytes"
"context"
"io"
"mime/multipart"
"net/http"
"net/url"
"strings"
"testing"
)
func TestExtractContext_GET(t *testing.T) {
tests := []struct {
name string
query string
wantModel string
wantErr bool
}{
{"model present", "model=llama3", "llama3", false},
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
{"model missing", "", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
got, err := ExtractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_JSON(t *testing.T) {
tests := []struct {
name string
body string
wantModel string
wantErr bool
}{
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
{"model empty string", `{"model":""}`, "", true},
{"model key missing", `{"stream":true}`, "", true},
{"invalid json", `not-json`, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
got, err := ExtractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_URLEncodedForm(t *testing.T) {
tests := []struct {
name string
formModel string
wantModel string
wantErr bool
}{
{"model present", "whisper-1", "whisper-1", false},
{"model missing", "", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
form := url.Values{}
if tt.formModel != "" {
form.Set("model", tt.formModel)
}
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
got, err := ExtractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_MultipartForm(t *testing.T) {
tests := []struct {
name string
formModel string
wantModel string
wantErr bool
}{
{"model present", "whisper-1", "whisper-1", false},
{"model missing", "", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
if tt.formModel != "" {
fw, _ := mw.CreateFormField("model")
fw.Write([]byte(tt.formModel))
}
mw.Close()
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
r.Header.Set("Content-Type", mw.FormDataContentType())
got, err := ExtractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_JSONBodyRestored(t *testing.T) {
body := `{"model":"llama3","stream":true}`
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/json")
if _, err := ExtractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if string(remaining) != body {
t.Errorf("body not restored: want %q got %q", body, string(remaining))
}
}
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
fw, _ := mw.CreateFormField("model")
fw.Write([]byte("whisper-1"))
ff, _ := mw.CreateFormFile("file", "audio.wav")
ff.Write([]byte("fake-audio-bytes"))
mw.Close()
original := buf.Bytes()
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
r.Header.Set("Content-Type", mw.FormDataContentType())
if _, err := ExtractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if !bytes.Equal(remaining, original) {
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
}
}
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
body := "model=whisper-1&extra=value"
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if _, err := ExtractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if string(remaining) != body {
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
}
}
func TestSetContext(t *testing.T) {
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
data, ok := ctx.Value(ContextKey).(ReqContextData)
if !ok {
t.Fatalf("ContextKey not set or wrong type")
}
if data.Model != "llama3" {
t.Errorf("want %q got %q", "llama3", data.Model)
}
if data.ModelID != "llama3" {
t.Errorf("want %q got %q", "llama3", data.ModelID)
}
}
func TestSetContext_WithAlias(t *testing.T) {
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
data, _ := ctx.Value(ContextKey).(ReqContextData)
if data.Model != "llama" {
t.Errorf("want requested %q got %q", "llama", data.Model)
}
if data.ModelID != "llama3" {
t.Errorf("want real %q got %q", "llama3", data.ModelID)
}
}
func TestSetContext_DoesNotMutateParent(t *testing.T) {
parent := context.Background()
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
if v := parent.Value(ContextKey); v != nil {
t.Errorf("parent context was mutated: %v", v)
}
}
func TestReadContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
wantReq string
wantReal string
wantBool bool
}{
{
name: "model present, same name",
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
wantReq: "llama3",
wantReal: "llama3",
wantBool: true,
},
{
name: "model present, aliased",
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
wantReq: "llama",
wantReal: "llama3",
wantBool: true,
},
{
name: "model absent",
ctx: context.Background(),
wantReq: "",
wantReal: "",
wantBool: false,
},
{
name: "model is empty string",
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
wantReq: "",
wantReal: "",
wantBool: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotData, ok := ReadContext(tt.ctx)
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
}
})
}
}