Compare commits

...

6 Commits

Author SHA1 Message Date
Benson Wong 09bdd86b54 Improve shutdown behaviour (#47) (#49)
Introduce `Process.Shutdown()` and `ProxyManager.Shutdown()`. These two function required a lot of internal process state management refactoring. A key benefit is that `Process.start()` is now interruptable. When `Shutdown()` is called it will break the long health check loop. 

State management within Process is also improved. Added `starting`, `stopping` and `shutdown` states. Additionally, introduced a simple finite state machine to manage transitions.
2025-02-05 17:19:59 -08:00
Benson Wong 85cd74a51c Improve process start and stop reliability (#38)
Refactor Process.start()/Stop() logic (#38)
- remove cmd.Wait() call in start(). This seems to conflict with the one
  in .Stop(). Removing it eliminated no child errors
- eliminate goroutines in .start() as it no longer required
2025-02-03 11:50:38 -08:00
Benson Wong 314d2f2212 remove cmd_stop configuration and functionality from PR #40 (#44)
* remove cmd_stop functionality from #40
2025-01-31 12:42:44 -08:00
Benson Wong fad25f3e11 Use client request context in proxy request (#43)
Canceled or closed HTTP requests from clients will also stop the proxied
HTTP requests to upstreamed servers.
2025-01-31 10:21:49 -08:00
Benson Wong 2c3e3e27f7 Support OPTIONS requests (#42)
Add middleware that responds with permissive OPTIONS headers
for all request paths.
2025-01-31 10:09:07 -08:00
Benson Wong baeb0c4e7f Add cmd_stop configuration to better support docker (#35)
Add `cmd_stop` to model configuration to run a command instead of sending a SIGTERM to shutdown a process before swapping.
2025-01-30 16:59:57 -08:00
7 changed files with 456 additions and 145 deletions
+11 -1
View File
@@ -5,7 +5,7 @@
# Introduction
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
@@ -30,6 +30,7 @@ Any OpenAI compatible server would work. llama-swap was originally designed for
- `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- ✅ Multiple GPU support
- ✅ Docker and Podman support
- ✅ Run multiple models at once with `profiles`
- ✅ Remote log monitoring at `/log`
- ✅ Automatic unloading of models from GPUs after timeout
@@ -89,6 +90,15 @@ models:
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
unlisted: true
# Docker Support (v26.1.4+ required!)
"docker-llama":
proxy: "http://127.0.0.1:9790"
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles make it easy to managing multi model (and gpu) configurations.
#
# Tips:
+8
View File
@@ -53,6 +53,14 @@ models:
--ctx-size 8192
--reranking
# Docker Support (v26.1.4+ required!)
"dockertest":
proxy: "http://127.0.0.1:9790"
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
"simple":
# example of setting environment variables
+12
View File
@@ -4,6 +4,8 @@ import (
"flag"
"fmt"
"os"
"os/signal"
"syscall"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy"
@@ -39,6 +41,16 @@ func main() {
}
proxyManager := proxy.New(config)
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
fmt.Println("Shutting down llama-swap")
proxyManager.Shutdown()
os.Exit(0)
}()
fmt.Println("llama-swap listening on " + *listenStr)
if err := proxyManager.Run(*listenStr); err != nil {
fmt.Printf("Server error: %v\n", err)
+221 -141
View File
@@ -17,14 +17,19 @@ import (
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateReady ProcessState = ProcessState("ready")
StateFailed ProcessState = ProcessState("failed")
StateStopped ProcessState = ProcessState("stopped")
StateStarting ProcessState = ProcessState("starting")
StateReady ProcessState = ProcessState("ready")
StateStopping ProcessState = ProcessState("stopping")
// failed a health check on start and will not be recovered
StateFailed ProcessState = ProcessState("failed")
// process is shutdown and will not be restarted
StateShutdown ProcessState = ProcessState("shutdown")
)
type Process struct {
sync.Mutex
ID string
config ModelConfig
cmd *exec.Cmd
@@ -37,9 +42,17 @@ type Process struct {
state ProcessState
inFlightRequests sync.WaitGroup
// used to block on multiple start() calls
waitStarting sync.WaitGroup
// for managing shutdown state
shutdownCtx context.Context
shutdownCancel context.CancelFunc
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background())
return &Process{
ID: ID,
config: config,
@@ -47,22 +60,88 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
}
}
// start the process and returns when it is ready
func (p *Process) setState(newState ProcessState) error {
// enforce valid state transitions
invalidTransition := false
if p.state == StateStopped {
// stopped -> starting
if newState != StateStarting {
invalidTransition = true
}
} else if p.state == StateStarting {
// starting -> ready | failed | stopping
if newState != StateReady && newState != StateFailed && newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateReady {
// ready -> stopping
if newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateStopping {
// stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
}
if invalidTransition {
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
}
p.state = newState
return nil
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
func (p *Process) start() error {
if p.config.Proxy == "" {
return fmt.Errorf("can not start(), upstream proxy missing")
}
// wait for the other start() to complete
curState := p.CurrentState()
if curState == StateReady {
return nil
}
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state != StateReady {
return fmt.Errorf("start() failed current state: %v", state)
}
return nil
}
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state == StateReady {
return nil
if err := p.setState(StateStarting); err != nil {
return err
}
if p.state == StateFailed {
return fmt.Errorf("process is in a failed state and can not be restarted")
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand()
if err != nil {
@@ -77,7 +156,8 @@ func (p *Process) start() error {
err = p.cmd.Start()
if err != nil {
return err
p.setState(StateFailed)
return fmt.Errorf("start() failed: %v", err)
}
// One of three things can happen at this stage:
@@ -86,35 +166,56 @@ func (p *Process) start() error {
// 3. The health check passes
//
// only in the third case will the process be considered Ready to accept
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
defer cancelHealthCheck(nil) // clean up
cmdWaitChan := make(chan error, 1)
healthCheckChan := make(chan error, 1)
<-time.After(250 * time.Millisecond) // give process a bit of time to start
go func() {
// possible cmd exits early
cmdWaitChan <- p.cmd.Wait()
}()
checkStartTime := time.Now()
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
go func() {
<-time.After(250 * time.Millisecond) // give process a bit of time to start
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
}()
select {
case err := <-cmdWaitChan:
p.state = StateFailed
if err != nil {
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
} else {
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
// a "none" means don't check for health ... I could have picked a better word :facepalm:
if checkEndpoint != "none" {
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
cancelHealthCheck(err)
return err
case err := <-healthCheckChan:
proxyTo := p.config.Proxy
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
p.state = StateFailed
return err
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
}
checkDeadline, cancelHealthCheck := context.WithDeadline(
context.Background(),
checkStartTime.Add(maxDuration),
)
defer cancelHealthCheck()
// Health check loop
loop:
for {
select {
case <-checkDeadline.Done():
p.setState(StateFailed)
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown")
default:
if err := p.checkHealthEndpoint(healthURL); err == nil {
cancelHealthCheck()
break loop
} else {
if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime)
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
} else {
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
}
}
}
<-time.After(time.Second)
}
}
@@ -141,162 +242,141 @@ func (p *Process) start() error {
}()
}
p.state = StateReady
return nil
return p.setState(StateReady)
}
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != StateReady {
// calling Stop() when state is invalid is a no-op
if err := p.setState(StateStopping); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
return
}
if p.cmd == nil || p.cmd.Process == nil {
// this situation should never happen... but if it does just update the state
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
p.state = StateStopped
return
// stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second)
if err := p.setState(StateStopped); err != nil {
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
}
}
// Pretty sure this stopping code needs some work for windows and
// will be a source of pain in the future.
// Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel()
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second)
if err := p.setState(StateShutdown); err != nil {
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
}
p.setState(StateShutdown)
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) {
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil {
panic("this should not happen, cmd or cmd.Process is nil")
}
p.cmd.Process.Signal(syscall.SIGTERM)
select {
case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "XXX Process for %s timed out waiting to stop, sending SIGKILL to PID: %d\n", p.ID, p.cmd.Process.Pid)
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.cmd.Process.Kill()
p.cmd.Wait()
case err := <-sigtermNormal:
if err != nil {
if err.Error() != "wait: no child processes" {
// possible that simple-responder for testing is just not
// existing right, so suppress those errors.
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
}
}
}
p.state = StateStopped
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
if p.config.Proxy == "" {
return fmt.Errorf("no upstream available to check /health")
}
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
if checkEndpoint == "none" {
return nil
}
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := p.config.Proxy
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
}
client := &http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
defer cancel()
req = req.WithContext(ctx)
resp, err := client.Do(req)
ttl := (maxDuration - time.Since(startTime)).Seconds()
if err != nil {
// check if the context was cancelled
select {
case <-ctx.Done():
err := context.Cause(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
return err
if errno, ok := err.(syscall.Errno); ok {
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
} else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
}
default:
}
// wait a bit longer for TCP connection issues
if strings.Contains(err.Error(), "connection refused") {
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
time.Sleep(5 * time.Second)
} else {
time.Sleep(time.Second)
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
}
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{
Timeout: 500 * time.Millisecond,
}
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// got a response but it was not an OK
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code: %d", resp.StatusCode)
}
return nil
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Add(1)
// prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState()
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
return
}
p.inFlightRequests.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.inFlightRequests.Done()
}()
// start the process on demand
if p.CurrentState() != StateReady {
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusInternalServerError)
http.Error(w, errstr, http.StatusBadGateway)
return
}
}
proxyTo := p.config.Proxy
client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
+115 -2
View File
@@ -48,6 +48,33 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
}
}
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
// are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop()
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(reqID int) {
defer wg.Done()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
}(i)
}
wg.Wait()
assert.Equal(t, StateReady, process.CurrentState())
}
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
@@ -58,13 +85,17 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
}
process := NewProcess("broken", 1, config, NewLogMonitor())
defer process.Stop()
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "unable to start process")
w = httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
}
func TestProcess_UnloadAfterTTL(t *testing.T) {
@@ -190,3 +221,85 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
assert.Equal(t, key, result)
}
}
func TestSetState(t *testing.T) {
tests := []struct {
name string
currentState ProcessState
newState ProcessState
expectedError error
expectedResult ProcessState
}{
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p := &Process{
state: test.currentState,
}
err := p.setState(test.newState)
if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil {
t.Errorf("Expected error: %v, but got none", test.expectedError)
} else if err != nil && test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
}
}
if p.state != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
}
})
}
}
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping long shutdown test")
}
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
// start a goroutine to simulate a shutdown
var wg sync.WaitGroup
go func() {
defer wg.Done()
<-time.After(time.Second * 2)
process.Shutdown()
}()
wg.Add(1)
// start the process, this is a blocking call
err := process.start()
wg.Wait()
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState())
}
+39 -1
View File
@@ -69,6 +69,19 @@ func New(config *Config) *ProxyManager {
})
}
// see: https://github.com/mostlygeek/llama-swap/issues/42
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.AbortWithStatus(204)
return
}
c.Next()
})
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12
@@ -143,13 +156,38 @@ func (pm *ProxyManager) stopProcesses() {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
process.Stop()
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait()
pm.currentProcesses = make(map[string]*Process)
}
// Shutdown is called to shutdown all upstream processes
// when llama-swap is shutting down.
func (pm *ProxyManager) Shutdown() {
pm.Lock()
defer pm.Unlock()
// shutdown process in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{}
for id, modelConfig := range pm.config.Models {
+50
View File
@@ -254,3 +254,53 @@ func TestProxyManager_ProfileNonMember(t *testing.T) {
assert.Equal(t, http.StatusNotFound, w.Code)
}
}
func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
model1Config.Proxy = "http://localhost:10001/"
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
model2Config.Proxy = "http://localhost:10002/"
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/"
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2", "model3"},
},
Models: map[string]ModelConfig{
"model1": model1Config,
"model2": model2Config,
"model3": model3Config,
},
}
proxy := New(config)
// Start all the processes
var wg sync.WaitGroup
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
wg.Add(1)
go func(modelName string) {
defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
// send a request to trigger the proxy to load
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
//fmt.Println(w.Code, w.Body.String())
}(modelName)
}
go func() {
<-time.After(time.Second)
proxy.Shutdown()
}()
wg.Wait()
}