diff --git a/proxy/matrix.go b/proxy/matrix.go index f3f7227a..70fc7738 100644 --- a/proxy/matrix.go +++ b/proxy/matrix.go @@ -147,6 +147,20 @@ type Matrix struct { config config.Config proxyLogger *LogMonitor upstreamLogger *LogMonitor + + // inflight tracks ProxyRequest calls from just before they release + // m.Lock until they return (Add under the lock, deferred Done). A + // request that needs to evict models waits for inflight to drain under + // m.Lock before stopping anything. Without this, a request that + // released m.Lock but has not yet reached Process.inFlightRequests.Add(1) + // races with Stop()'s Wait() and can be killed mid-request. + inflight sync.WaitGroup + + // testDelayFastPath is a test-only hook invoked in the no-eviction path + // after m.Lock is released but before the request is dispatched to + // Process.ProxyRequest. Tests use it to park a request at the exact + // race window to deterministically reproduce the race. + testDelayFastPath func() } // NewMatrix creates a Matrix from config. It creates a Process for every @@ -197,6 +211,13 @@ func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Req // Evict models that need to be stopped if len(result.Evict) > 0 { + // Wait for every in-flight ProxyRequest call to return before + // stopping anything; this blocks while m.Lock is still held. Without + // this, a request that released m.Lock but has not yet incremented + // Process.inFlightRequests races with Stop() and can be killed + // mid-request. + m.inflight.Wait() + var wg sync.WaitGroup for _, evictModel := range result.Evict { if p, exists := m.processes[evictModel]; exists { @@ -209,8 +230,18 @@ func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Req } wg.Wait() } + + // Register this request in inflight before releasing m.Lock so a + // concurrent eviction will wait for it to complete. 
+ m.inflight.Add(1) + defer m.inflight.Done() + isFastPath := len(result.Evict) == 0 m.Unlock() + if isFastPath && m.testDelayFastPath != nil { + m.testDelayFastPath() + } + // Proxy the request (Process handles on-demand start) process.ProxyRequest(w, r) return nil diff --git a/proxy/matrix_test.go b/proxy/matrix_test.go index 81d5a1a8..8b921379 100644 --- a/proxy/matrix_test.go +++ b/proxy/matrix_test.go @@ -1,7 +1,11 @@ package proxy import ( + "net/http" + "net/http/httptest" + "runtime" "testing" + "time" "github.com/mostlygeek/llama-swap/proxy/config" "github.com/stretchr/testify/assert" @@ -169,6 +173,124 @@ func TestMatrixSolver_NothingRunning(t *testing.T) { assert.Equal(t, []string{"g", "v"}, result.TargetSet) } +// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction +// cannot stop a process while an in-flight ProxyRequest for that process is +// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without +// matrix-level inflight tracking, the eviction's Stop() races with the +// pending request and kills it mid-start. +func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) { + cfg := config.Config{ + HealthCheckTimeout: 15, + Models: map[string]config.ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + }, + ExpandedSets: []config.ExpandedSet{ + {SetName: "s1", Models: []string{"model1"}}, + {SetName: "s2", Models: []string{"model2"}}, + }, + Matrix: &config.MatrixConfig{}, + } + + m := NewMatrix(cfg, testLogger, testLogger) + defer m.StopProcesses(StopImmediately) + + // Bypass real subprocesses so the test is fast and deterministic. + m.processes["model1"].testHandler = newTestHandler("model1") + m.processes["model2"].testHandler = newTestHandler("model2") + + // Prime: run a request through model1 so it reaches StateReady and + // subsequent requests take the no-eviction path. 
+ primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil) + primeW := httptest.NewRecorder() + require.NoError(t, m.ProxyRequest("model1", primeW, primeReq)) + require.Equal(t, http.StatusOK, primeW.Code) + require.Equal(t, StateReady, m.processes["model1"].CurrentState()) + require.Equal(t, StateStopped, m.processes["model2"].CurrentState()) + + // Install fast-path hook that signals arrival and waits for release. + // This parks R2 at the race window — after m.Lock is released but + // before Process.inFlightRequests.Add(1). + r2Reached := make(chan struct{}) + r2Release := make(chan struct{}) + m.testDelayFastPath = func() { + close(r2Reached) + <-r2Release + } + + // R2: no-eviction request for model1. Will pause at the hook. + r2Done := make(chan struct{}) + w2 := httptest.NewRecorder() + go func() { + defer close(r2Done) + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + assert.NoError(t, m.ProxyRequest("model1", w2, req)) + }() + + // Deterministically wait for R2 to reach the race window. + <-r2Reached + + // R3: request for model2 which requires evicting model1. Must wait for + // R2 to finish before touching model1. + r3Done := make(chan struct{}) + w3 := httptest.NewRecorder() + go func() { + defer close(r3Done) + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + assert.NoError(t, m.ProxyRequest("model2", w3, req)) + }() + + // Spin until R3 has acquired m.Lock and entered the eviction path. In + // the fixed code, R3 then blocks on m.inflight.Wait() while still + // holding the lock, so TryLock keeps failing. + for m.TryLock() { + m.Unlock() + runtime.Gosched() + } + + // Bounded poll: give R3 a chance to demonstrate the bug by mutating + // state. In the fixed code R3 is blocked and nothing changes; in the + // buggy code R3 will Stop() model1 and start model2 within microseconds. 
+ deadline := time.Now().Add(100 * time.Millisecond) + for time.Now().Before(deadline) { + if m.processes["model1"].CurrentState() != StateReady || + m.processes["model2"].CurrentState() != StateStopped { + break + } + done := false + select { + case <-r3Done: + done = true + default: + } + if done { + break + } + runtime.Gosched() + } + + // Invariant: R3 must be blocked while R2 is still in flight. + select { + case <-r3Done: + t.Fatal("eviction completed while in-flight request was still pending — race not prevented") + default: + } + assert.Equal(t, StateReady, m.processes["model1"].CurrentState(), + "model1 must stay Ready while an in-flight request is pending") + assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(), + "model2 must not be started until R2 finishes and model1 is evicted") + + // Release R2 and let both requests finish. + close(r2Release) + <-r2Done + <-r3Done + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Contains(t, w2.Body.String(), "model1") + assert.Equal(t, http.StatusOK, w3.Code) + assert.Contains(t, w3.Body.String(), "model2") +} + func TestMatrixSolver_FullScenario(t *testing.T) { // Simulates the example config: // standard: [g,v], [q,v], [m,v] diff --git a/proxy/processgroup.go b/proxy/processgroup.go index b401d8a6..c3055e0c 100644 --- a/proxy/processgroup.go +++ b/proxy/processgroup.go @@ -24,6 +24,22 @@ type ProcessGroup struct { // map of current processes processes map[string]*Process lastUsedProcess string + + // inflight tracks fast-path requests (requests for the already-selected + // model in a swap group). Fast-path requests Add(1) while holding pg.Lock + // and Done() on completion; a concurrent swap request calls inflight.Wait() + // under pg.Lock before stopping the current process. Without this tracking, + // a fast-path request that has released pg.Lock but has not yet called + // Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be + // killed mid-request. 
+ inflight sync.WaitGroup + + // testDelayFastPath is a test-only hook that, when non-nil, is invoked in + // the fast path after pg.Lock is released but before the request is + // dispatched to Process.ProxyRequest. Tests use it to park a fast-path + // request at the exact race window to deterministically reproduce the + // fast-path vs swap race. + testDelayFastPath func() } func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { @@ -64,6 +80,13 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, pg.Lock() if pg.lastUsedProcess != modelID { + // Wait for in-flight fast-path requests to drain before stopping + // the previous process. Without this, a fast-path request that has + // released pg.Lock but has not yet incremented + // Process.inFlightRequests races with Stop() and can be killed + // mid-request. + pg.inflight.Wait() + // is there something already running? if pg.lastUsedProcess != "" { pg.processes[pg.lastUsedProcess].Stop() @@ -78,7 +101,16 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, pg.Unlock() return nil } + + // Fast path: register this request in inflight before releasing + // pg.Lock so a concurrent swap will wait for it to complete. 
+ pg.inflight.Add(1) + defer pg.inflight.Done() pg.Unlock() + + if pg.testDelayFastPath != nil { + pg.testDelayFastPath() + } } pg.processes[modelID].ProxyRequest(writer, request) @@ -123,6 +155,10 @@ func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) { pg.Lock() defer pg.Unlock() + if strategy != StopImmediately { + pg.inflight.Wait() + } + if len(pg.processes) == 0 { return } diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go index 6b90f443..d261baeb 100644 --- a/proxy/processgroup_test.go +++ b/proxy/processgroup_test.go @@ -4,11 +4,14 @@ import ( "bytes" "net/http" "net/http/httptest" + "runtime" "sync" "testing" + "time" "github.com/mostlygeek/llama-swap/proxy/config" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ @@ -95,6 +98,229 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { wg.Wait() } +// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap +// request cannot stop the current process while a fast-path request (for the +// already-selected model) is in flight. Without ProcessGroup-level inflight +// tracking, a fast-path request that has released pg.Lock but has not yet +// incremented Process.inFlightRequests races with Stop()'s Wait() and the +// process is killed mid-request. +func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) { + cfg := config.AddDefaultGroupToConfig(config.Config{ + HealthCheckTimeout: 15, + Models: map[string]config.ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + }, + Groups: map[string]config.GroupConfig{ + "G1": { + Swap: true, + Members: []string{"model1", "model2"}, + }, + }, + }) + + pg := NewProcessGroup("G1", cfg, testLogger, testLogger) + defer pg.StopProcesses(StopImmediately) + + // Bypass real subprocesses so the test is fast and deterministic. 
+ pg.processes["model1"].testHandler = newTestHandler("model1") + pg.processes["model2"].testHandler = newTestHandler("model2") + + // Prime: run a request through model1 via the swap path so that + // lastUsedProcess == "model1" and subsequent model1 requests take the + // fast path. + primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil) + primeW := httptest.NewRecorder() + require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq)) + require.Equal(t, http.StatusOK, primeW.Code) + require.Equal(t, StateReady, pg.processes["model1"].CurrentState()) + require.Equal(t, StateStopped, pg.processes["model2"].CurrentState()) + + // Fast-path hook: signal arrival at the race window, then wait for + // release. This parks R2 deterministically at the point where pg.Lock + // has been released but Process.inFlightRequests has not yet been + // incremented — the exact window the race exploits. + r2Reached := make(chan struct{}) + r2Release := make(chan struct{}) + pg.testDelayFastPath = func() { + close(r2Reached) + <-r2Release + } + + // R2: fast-path request for model1. Will pause at the test hook. + r2Done := make(chan struct{}) + w2 := httptest.NewRecorder() + go func() { + defer close(r2Done) + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + assert.NoError(t, pg.ProxyRequest("model1", w2, req)) + }() + + // Deterministically wait for R2 to reach the race window. + <-r2Reached + + // R3: swap request for model2. Must wait for R2 to finish before touching + // model1, otherwise model1 gets killed mid-request. + r3Done := make(chan struct{}) + w3 := httptest.NewRecorder() + go func() { + defer close(r3Done) + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + assert.NoError(t, pg.ProxyRequest("model2", w3, req)) + }() + + // Spin until R3 has acquired pg.Lock and entered the swap critical + // section. In the fixed code, R3 then blocks on pg.inflight.Wait() while + // still holding the lock, so TryLock keeps failing. 
+ for pg.TryLock() { + pg.Unlock() + runtime.Gosched() + } + + // Bounded poll: give R3 a chance to demonstrate the bug by mutating + // state. In the fixed code, R3 is blocked on pg.inflight.Wait() and + // nothing changes, so we wait the full window. In the buggy code, R3 + // will Stop() model1 and start serving via model2 within microseconds — + // we exit early once the mutation is observable. + deadline := time.Now().Add(100 * time.Millisecond) + for time.Now().Before(deadline) { + if pg.processes["model1"].CurrentState() != StateReady || + pg.processes["model2"].CurrentState() != StateStopped { + break + } + done := false + select { + case <-r3Done: + done = true + default: + } + if done { + break + } + runtime.Gosched() + } + + // Invariant: R3 must be blocked while R2 is still in flight. + select { + case <-r3Done: + t.Fatal("swap completed while fast-path request was still in flight — race not prevented") + default: + } + assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(), + "model1 must stay Ready while a fast-path request is in flight") + assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(), + "model2 must not be started until R2 finishes and model1 is swapped out") + + // Release R2 and let both requests finish. + close(r2Release) + <-r2Done + <-r3Done + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Contains(t, w2.Body.String(), "model1") + assert.Equal(t, http.StatusOK, w3.Code) + assert.Contains(t, w3.Body.String(), "model2") +} + +// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses +// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a +// process while a fast-path ProxyRequest is in the [pg.Unlock, +// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in +// StopProcesses, the external caller bypasses the inflight guard and kills the +// process mid-request. 
+func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) { + cfg := config.AddDefaultGroupToConfig(config.Config{ + HealthCheckTimeout: 15, + Models: map[string]config.ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + }, + Groups: map[string]config.GroupConfig{ + "G1": { + Swap: true, + Members: []string{"model1", "model2"}, + }, + }, + }) + + pg := NewProcessGroup("G1", cfg, testLogger, testLogger) + defer pg.StopProcesses(StopImmediately) + + pg.processes["model1"].testHandler = newTestHandler("model1") + pg.processes["model2"].testHandler = newTestHandler("model2") + + // Prime: model1 is active so subsequent model1 requests take the fast path. + primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil) + primeW := httptest.NewRecorder() + require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq)) + require.Equal(t, http.StatusOK, primeW.Code) + require.Equal(t, StateReady, pg.processes["model1"].CurrentState()) + + // Park a fast-path request at the race window. + r2Reached := make(chan struct{}) + r2Release := make(chan struct{}) + pg.testDelayFastPath = func() { + close(r2Reached) + <-r2Release + } + + r2Done := make(chan struct{}) + w2 := httptest.NewRecorder() + go func() { + defer close(r2Done) + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + assert.NoError(t, pg.ProxyRequest("model1", w2, req)) + }() + + <-r2Reached + + // Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping + // the group while a fast-path request is in flight. + r3Done := make(chan struct{}) + go func() { + defer close(r3Done) + pg.StopProcesses(StopWaitForInflightRequest) + }() + + // Spin until StopProcesses has acquired pg.Lock. + for pg.TryLock() { + pg.Unlock() + runtime.Gosched() + } + + // Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait() + // and model1 stays Ready. 
In the buggy code it proceeds immediately and + // kills model1. + deadline := time.Now().Add(100 * time.Millisecond) + for time.Now().Before(deadline) { + if pg.processes["model1"].CurrentState() != StateReady { + break + } + select { + case <-r3Done: + goto done + default: + } + runtime.Gosched() + } +done: + + select { + case <-r3Done: + t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented") + default: + } + assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(), + "model1 must stay Ready while a fast-path request is in flight") + + close(r2Release) + <-r2Done + <-r3Done + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Contains(t, w2.Body.String(), "model1") +} + func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) defer pg.StopProcesses(StopWaitForInflightRequest)