Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 519c3a4d22 |
+1
-1
@@ -84,7 +84,7 @@ func main() {
|
|||||||
case newManager := <-reloadChan:
|
case newManager := <-reloadChan:
|
||||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
||||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
// Stop old manager processes gracefully (this waits for in-flight requests)
|
||||||
currentManager.StopProcesses()
|
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
||||||
// Now do a full shutdown to clear the process map
|
// Now do a full shutdown to clear the process map
|
||||||
currentManager.Shutdown()
|
currentManager.Shutdown()
|
||||||
currentManager = newManager
|
currentManager = newManager
|
||||||
|
|||||||
+21
-1
@@ -30,6 +30,13 @@ const (
|
|||||||
StateShutdown ProcessState = ProcessState("shutdown")
|
StateShutdown ProcessState = ProcessState("shutdown")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type StopStrategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StopImmediately StopStrategy = iota
|
||||||
|
StopWaitForInflightRequest
|
||||||
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config ModelConfig
|
||||||
@@ -313,13 +320,25 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop will wait for inflight requests to complete before stopping the process.
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for any inflight requests before proceeding
|
// wait for any inflight requests before proceeding
|
||||||
|
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
p.StopImmediately()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||||
|
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||||
|
func (p *Process) StopImmediately() {
|
||||||
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
// calling Stop() when state is invalid is a no-op
|
||||||
@@ -338,7 +357,8 @@ func (p *Process) Stop() {
|
|||||||
|
|
||||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||||
// of time for any inflight requests to complete before shutting down. If the Process
|
// of time for any inflight requests to complete before shutting down. If the Process
|
||||||
// is in the state of starting, it will cancel it and shut it down
|
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||||
|
// the StateShutdown state, it can not be started again.
|
||||||
func (p *Process) Shutdown() {
|
func (p *Process) Shutdown() {
|
||||||
p.shutdownCancel()
|
p.shutdownCancel()
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(5 * time.Second)
|
||||||
|
|||||||
@@ -372,3 +372,24 @@ func TestProcess_ConcurrencyLimit(t *testing.T) {
|
|||||||
process.ProxyRequest(w, denied)
|
process.ProxyRequest(w, denied)
|
||||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_StopImmediately(t *testing.T) {
|
||||||
|
expectedMessage := "test_stop_immediate"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
go func() {
|
||||||
|
// slow, but will get killed by StopImmediate
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
}()
|
||||||
|
<-time.After(time.Millisecond)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|||||||
@@ -76,14 +76,10 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
|||||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pg *ProcessGroup) StopProcesses() {
|
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
defer pg.Unlock()
|
defer pg.Unlock()
|
||||||
pg.stopProcesses()
|
|
||||||
}
|
|
||||||
|
|
||||||
// stopProcesses stops all processes in the group
|
|
||||||
func (pg *ProcessGroup) stopProcesses() {
|
|
||||||
if len(pg.processes) == 0 {
|
if len(pg.processes) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -94,7 +90,12 @@ func (pg *ProcessGroup) stopProcesses() {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(process *Process) {
|
go func(process *Process) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
process.Stop()
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
}(process)
|
}(process)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
||||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses()
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2"}
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses()
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model3", "model4"}
|
tests := []string{"model3", "model4"}
|
||||||
|
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// This is the public method safe for concurrent calls.
|
// This is the public method safe for concurrent calls.
|
||||||
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
||||||
// a complete shutdown, allowing for process replacement without full termination.
|
// a complete shutdown, allowing for process replacement without full termination.
|
||||||
func (pm *ProxyManager) StopProcesses() {
|
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
@@ -218,7 +218,7 @@ func (pm *ProxyManager) StopProcesses() {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(processGroup *ProcessGroup) {
|
go func(processGroup *ProcessGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
processGroup.stopProcesses()
|
processGroup.StopProcesses(strategy)
|
||||||
}(processGroup)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,7 +260,7 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
|||||||
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
||||||
for groupId, otherGroup := range pm.processGroups {
|
for groupId, otherGroup := range pm.processGroups {
|
||||||
if groupId != processGroup.id && !otherGroup.persistent {
|
if groupId != processGroup.id && !otherGroup.persistent {
|
||||||
otherGroup.StopProcesses()
|
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -504,7 +504,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
pm.StopProcesses()
|
pm.StopProcesses(StopImmediately)
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+10
-10
@@ -27,7 +27,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
for _, modelName := range []string{"model1", "model2"} {
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
@@ -63,7 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2"}
|
||||||
for _, requestedModel := range tests {
|
for _, requestedModel := range tests {
|
||||||
@@ -105,7 +105,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
// make requests to load all models, loading model1 should not affect model2
|
// make requests to load all models, loading model1 should not affect model2
|
||||||
tests := []string{"model2", "model1"}
|
tests := []string{"model2", "model1"}
|
||||||
@@ -141,7 +141,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
results := map[string]string{}
|
results := map[string]string{}
|
||||||
|
|
||||||
@@ -352,7 +352,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
// Create proxy once for all tests
|
// Create proxy once for all tests
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
t.Run("no models loaded", func(t *testing.T) {
|
t.Run("no models loaded", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/running", nil)
|
req := httptest.NewRequest("GET", "/running", nil)
|
||||||
@@ -407,7 +407,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
// Create a buffer with multipart form data
|
// Create a buffer with multipart form data
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
@@ -461,7 +461,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
requestedModel := "model1"
|
requestedModel := "model1"
|
||||||
|
|
||||||
@@ -557,7 +557,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
||||||
for k, v := range tt.requestHeaders {
|
for k, v := range tt.requestHeaders {
|
||||||
@@ -586,7 +586,7 @@ func TestProxyManager_Upstream(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
@@ -604,7 +604,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
|||||||
Reference in New Issue
Block a user