Compare commits

...

1 Commits

Author SHA1 Message Date
Benson Wong 519c3a4d22 Change /unload to not wait for inflight requests (#125)
Sometimes upstreams can accept HTTP requests but never respond, causing requests
to build up waiting for a response. This can block Process.Stop(), as
that waits for inflight requests to finish. This change refactors the
code to not wait when attempting to shut down the process.
2025-05-13 11:39:19 -07:00
7 changed files with 66 additions and 24 deletions
+1 -1
View File
@@ -84,7 +84,7 @@ func main() {
case newManager := <-reloadChan: case newManager := <-reloadChan:
log.Println("Config change detected, waiting for in-flight requests to complete...") log.Println("Config change detected, waiting for in-flight requests to complete...")
// Stop old manager processes gracefully (this waits for in-flight requests) // Stop old manager processes gracefully (this waits for in-flight requests)
currentManager.StopProcesses() currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
// Now do a full shutdown to clear the process map // Now do a full shutdown to clear the process map
currentManager.Shutdown() currentManager.Shutdown()
currentManager = newManager currentManager = newManager
+21 -1
View File
@@ -30,6 +30,13 @@ const (
StateShutdown ProcessState = ProcessState("shutdown") StateShutdown ProcessState = ProcessState("shutdown")
) )
// StopStrategy selects how a process is stopped: either right away, or only
// after its in-flight requests have completed.
type StopStrategy int
const (
// StopImmediately stops the process without waiting for in-flight requests.
// It is the zero value of StopStrategy.
StopImmediately StopStrategy = iota
// StopWaitForInflightRequest waits for in-flight requests to complete
// before the process is stopped.
StopWaitForInflightRequest
)
type Process struct { type Process struct {
ID string ID string
config ModelConfig config ModelConfig
@@ -313,13 +320,25 @@ func (p *Process) start() error {
} }
} }
// Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() { func (p *Process) Stop() {
if !isValidTransition(p.CurrentState(), StateStopping) { if !isValidTransition(p.CurrentState(), StateStopping) {
return return
} }
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.StopImmediately()
}
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
p.proxyLogger.Debugf("<%s> Stopping process", p.ID) p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
// calling Stop() when state is invalid is a no-op // calling Stop() when state is invalid is a no-op
@@ -338,7 +357,8 @@ func (p *Process) Stop() {
// Shutdown is called when llama-swap is shutting down. It will give a little bit // Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process // of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down // is in the state of starting, it will cancel it and shut it down. Once a process is in
// the StateShutdown state, it can not be started again.
func (p *Process) Shutdown() { func (p *Process) Shutdown() {
p.shutdownCancel() p.shutdownCancel()
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
+21
View File
@@ -372,3 +372,24 @@ func TestProcess_ConcurrencyLimit(t *testing.T) {
process.ProxyRequest(w, denied) process.ProxyRequest(w, denied)
assert.Equal(t, http.StatusTooManyRequests, w.Code) assert.Equal(t, http.StatusTooManyRequests, w.Code)
} }
// TestProcess_StopImmediately verifies that StopImmediately moves a ready
// process to StateStopped even while a slow request is still being proxied.
func TestProcess_StopImmediately(t *testing.T) {
	expectedMessage := "test_stop_immediate"
	config := getTestSimpleResponderConfig(expectedMessage)

	process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
	// Safety net: make sure the process is cleaned up if an assertion below
	// fails before StopImmediately is reached.
	defer process.Stop()

	// Without a running process the rest of the test is meaningless, so bail
	// out early instead of piling follow-on failures on top of the real one.
	if err := process.start(); !assert.NoError(t, err) {
		return
	}
	// testify's Equal takes (expected, actual); keeping that order makes the
	// failure output label the two values correctly.
	assert.Equal(t, StateReady, process.CurrentState())

	go func() {
		// Slow request (1s delay); expected to be cut short when
		// StopImmediately kills the process.
		req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
		w := httptest.NewRecorder()
		process.ProxyRequest(w, req)
	}()

	// Give the goroutine a moment to get its request in flight.
	time.Sleep(time.Millisecond)

	process.StopImmediately()
	assert.Equal(t, StateStopped, process.CurrentState())
}
+7 -6
View File
@@ -76,14 +76,10 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName) return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
} }
func (pg *ProcessGroup) StopProcesses() { func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock() pg.Lock()
defer pg.Unlock() defer pg.Unlock()
pg.stopProcesses()
}
// stopProcesses stops all processes in the group
func (pg *ProcessGroup) stopProcesses() {
if len(pg.processes) == 0 { if len(pg.processes) == 0 {
return return
} }
@@ -94,7 +90,12 @@ func (pg *ProcessGroup) stopProcesses() {
wg.Add(1) wg.Add(1)
go func(process *Process) { go func(process *Process) {
defer wg.Done() defer wg.Done()
process.Stop() switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
}(process) }(process)
} }
wg.Wait() wg.Wait()
+2 -2
View File
@@ -46,7 +46,7 @@ func TestProcessGroup_HasMember(t *testing.T) {
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses() defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"} tests := []string{"model1", "model2"}
@@ -74,7 +74,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses() defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model3", "model4"} tests := []string{"model3", "model4"}
+4 -4
View File
@@ -208,7 +208,7 @@ func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// This is the public method safe for concurrent calls. // This is the public method safe for concurrent calls.
// Unlike Shutdown, this method only stops the processes but doesn't perform // Unlike Shutdown, this method only stops the processes but doesn't perform
// a complete shutdown, allowing for process replacement without full termination. // a complete shutdown, allowing for process replacement without full termination.
func (pm *ProxyManager) StopProcesses() { func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
@@ -218,7 +218,7 @@ func (pm *ProxyManager) StopProcesses() {
wg.Add(1) wg.Add(1)
go func(processGroup *ProcessGroup) { go func(processGroup *ProcessGroup) {
defer wg.Done() defer wg.Done()
processGroup.stopProcesses() processGroup.StopProcesses(strategy)
}(processGroup) }(processGroup)
} }
@@ -260,7 +260,7 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id) pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups { for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent { if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses() otherGroup.StopProcesses(StopWaitForInflightRequest)
} }
} }
} }
@@ -504,7 +504,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
} }
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses() pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK") c.String(http.StatusOK, "OK")
} }
+10 -10
View File
@@ -27,7 +27,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
for _, modelName := range []string{"model1", "model2"} { for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
@@ -63,7 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"} tests := []string{"model1", "model2"}
for _, requestedModel := range tests { for _, requestedModel := range tests {
@@ -105,7 +105,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
// make requests to load all models, loading model1 should not affect model2 // make requests to load all models, loading model1 should not affect model2
tests := []string{"model2", "model1"} tests := []string{"model2", "model1"}
@@ -141,7 +141,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
results := map[string]string{} results := map[string]string{}
@@ -352,7 +352,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Create proxy once for all tests // Create proxy once for all tests
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("no models loaded", func(t *testing.T) { t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil) req := httptest.NewRequest("GET", "/running", nil)
@@ -407,7 +407,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
// Create a buffer with multipart form data // Create a buffer with multipart form data
var b bytes.Buffer var b bytes.Buffer
@@ -461,7 +461,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
requestedModel := "model1" requestedModel := "model1"
@@ -557,7 +557,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil) req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders { for k, v := range tt.requestHeaders {
@@ -586,7 +586,7 @@ func TestProxyManager_Upstream(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/upstream/model1/test", nil) req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req) proxy.ServeHTTP(rec, req)
@@ -604,7 +604,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))