Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 09bdd86b54 | |||
| 85cd74a51c | |||
| 314d2f2212 | |||
| fad25f3e11 | |||
| 2c3e3e27f7 | |||
| baeb0c4e7f | |||
| 2833517eef |
@@ -5,7 +5,7 @@
|
||||
# Introduction
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
|
||||
|
||||
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
|
||||
|
||||
@@ -30,6 +30,7 @@ Any OpenAI compatible server would work. llama-swap was originally designed for
|
||||
- `v1/rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Docker and Podman support
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
@@ -89,6 +90,15 @@ models:
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
unlisted: true
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:9790"
|
||||
cmd: >
|
||||
docker run --name dockertest
|
||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
#
|
||||
# Tips:
|
||||
|
||||
@@ -53,6 +53,14 @@ models:
|
||||
--ctx-size 8192
|
||||
--reranking
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"dockertest":
|
||||
proxy: "http://127.0.0.1:9790"
|
||||
cmd: >
|
||||
docker run --name dockertest
|
||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
@@ -39,6 +41,16 @@ func main() {
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigChan
|
||||
fmt.Println("Shutting down llama-swap")
|
||||
proxyManager.Shutdown()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := proxyManager.Run(*listenStr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
|
||||
+224
-142
@@ -17,14 +17,19 @@ import (
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateFailed ProcessState = ProcessState("failed")
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// failed a health check on start and will not be recovered
|
||||
StateFailed ProcessState = ProcessState("failed")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
sync.Mutex
|
||||
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
@@ -37,9 +42,17 @@ type Process struct {
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
|
||||
// for managing shutdown state
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
@@ -47,22 +60,88 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
|
||||
logMonitor: logMonitor,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
state: StateStopped,
|
||||
shutdownCtx: ctx,
|
||||
shutdownCancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// start the process and returns when it is ready
|
||||
func (p *Process) setState(newState ProcessState) error {
|
||||
// enforce valid state transitions
|
||||
invalidTransition := false
|
||||
if p.state == StateStopped {
|
||||
// stopped -> starting
|
||||
if newState != StateStarting {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateStarting {
|
||||
// starting -> ready | failed | stopping
|
||||
if newState != StateReady && newState != StateFailed && newState != StateStopping {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateReady {
|
||||
// ready -> stopping
|
||||
if newState != StateStopping {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateStopping {
|
||||
// stopping -> stopped | shutdown
|
||||
if newState != StateStopped && newState != StateShutdown {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateFailed || p.state == StateShutdown {
|
||||
invalidTransition = true
|
||||
}
|
||||
|
||||
if invalidTransition {
|
||||
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
|
||||
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
func (p *Process) start() error {
|
||||
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||
}
|
||||
|
||||
// wait for the other start() to complete
|
||||
curState := p.CurrentState()
|
||||
|
||||
if curState == StateReady {
|
||||
return nil
|
||||
}
|
||||
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
|
||||
if state := p.CurrentState(); state != StateReady {
|
||||
return fmt.Errorf("start() failed current state: %v", state)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state == StateReady {
|
||||
return nil
|
||||
if err := p.setState(StateStarting); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.state == StateFailed {
|
||||
return fmt.Errorf("process is in a failed state and can not be restarted")
|
||||
}
|
||||
p.waitStarting.Add(1)
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
@@ -77,7 +156,8 @@ func (p *Process) start() error {
|
||||
err = p.cmd.Start()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
p.setState(StateFailed)
|
||||
return fmt.Errorf("start() failed: %v", err)
|
||||
}
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
@@ -86,35 +166,56 @@ func (p *Process) start() error {
|
||||
// 3. The health check passes
|
||||
//
|
||||
// only in the third case will the process be considered Ready to accept
|
||||
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
|
||||
defer cancelHealthCheck(nil) // clean up
|
||||
cmdWaitChan := make(chan error, 1)
|
||||
healthCheckChan := make(chan error, 1)
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
|
||||
go func() {
|
||||
// possible cmd exits early
|
||||
cmdWaitChan <- p.cmd.Wait()
|
||||
}()
|
||||
checkStartTime := time.Now()
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
go func() {
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-cmdWaitChan:
|
||||
p.state = StateFailed
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
|
||||
} else {
|
||||
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
|
||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||
if checkEndpoint != "none" {
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
cancelHealthCheck(err)
|
||||
return err
|
||||
case err := <-healthCheckChan:
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
p.state = StateFailed
|
||||
return err
|
||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
||||
context.Background(),
|
||||
checkStartTime.Add(maxDuration),
|
||||
)
|
||||
defer cancelHealthCheck()
|
||||
|
||||
// Health check loop
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-checkDeadline.Done():
|
||||
p.setState(StateFailed)
|
||||
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
|
||||
case <-p.shutdownCtx.Done():
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
default:
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
cancelHealthCheck()
|
||||
break loop
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
endTime, _ := checkDeadline.Deadline()
|
||||
ttl := time.Until(endTime)
|
||||
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
<-time.After(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,166 +236,147 @@ func (p *Process) start() error {
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
p.state = StateReady
|
||||
return nil
|
||||
return p.setState(StateReady)
|
||||
}
|
||||
|
||||
func (p *Process) Stop() {
|
||||
// wait for any inflight requests before proceeding
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != StateReady {
|
||||
// calling Stop() when state is invalid is a no-op
|
||||
if err := p.setState(StateStopping); err != nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
// this situation should never happen... but if it does just update the state
|
||||
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
|
||||
p.state = StateStopped
|
||||
return
|
||||
// stop the process with a graceful exit timeout
|
||||
p.stopCommand(5 * time.Second)
|
||||
|
||||
if err := p.setState(StateStopped); err != nil {
|
||||
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Pretty sure this stopping code needs some work for windows and
|
||||
// will be a source of pain in the future.
|
||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down
|
||||
func (p *Process) Shutdown() {
|
||||
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
|
||||
p.shutdownCancel()
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
p.setState(StateStopping)
|
||||
|
||||
// 5 seconds to stop the process
|
||||
p.stopCommand(5 * time.Second)
|
||||
if err := p.setState(StateShutdown); err != nil {
|
||||
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
|
||||
}
|
||||
p.setState(StateShutdown)
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
||||
defer cancelTimeout()
|
||||
|
||||
sigtermNormal := make(chan error, 1)
|
||||
go func() {
|
||||
sigtermNormal <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
panic("this should not happen, cmd or cmd.Process is nil")
|
||||
}
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s timed out waiting to stop\n", p.ID)
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
p.cmd.Wait()
|
||||
case err := <-sigtermNormal:
|
||||
if err != nil {
|
||||
if err.Error() != "wait: no child processes" {
|
||||
// possible that simple-responder for testing is just not
|
||||
// existing right, so suppress those errors.
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
p.state = StateStopped
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||
|
||||
if err != nil {
|
||||
// check if the context was cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := context.Cause(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
if errno, ok := err.(syscall.Errno); ok {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
|
||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// wait a bit longer for TCP connection issues
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// got a response but it was not an OK
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
// prevent new requests from being made while stopping or irrecoverable
|
||||
currentState := p.CurrentState()
|
||||
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
http.Error(w, errstr, http.StatusInternalServerError)
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
+143
-4
@@ -48,6 +48,33 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
@@ -58,16 +85,19 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
||||
}
|
||||
|
||||
// test that the process unloads after the TTL
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
@@ -79,7 +109,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -111,6 +141,33 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_LowTTLValue(t *testing.T) {
|
||||
if true { // change this code to run this ...
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
config := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Logf("Waiting before sending request %d", i)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
expected := fmt.Sprintf("echo=test_%d", i)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// issue #19
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
@@ -164,3 +221,85 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
|
||||
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
|
||||
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
|
||||
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
|
||||
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
|
||||
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
|
||||
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
|
||||
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := &Process{
|
||||
state: test.currentState,
|
||||
}
|
||||
|
||||
err := p.setState(test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||
} else if err != nil && test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.state != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Second * 2)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
// start the process, this is a blocking call
|
||||
err := process.start()
|
||||
|
||||
wg.Wait()
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
|
||||
+39
-1
@@ -69,6 +69,19 @@ func New(config *Config) *ProxyManager {
|
||||
})
|
||||
}
|
||||
|
||||
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
||||
// respond with permissive OPTIONS for any endpoint
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
@@ -143,13 +156,38 @@ func (pm *ProxyManager) stopProcesses() {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
process.Stop()
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Stop()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
pm.currentProcesses = make(map[string]*Process)
|
||||
}
|
||||
|
||||
// Shutdown is called to shutdown all upstream processes
|
||||
// when llama-swap is shutting down.
|
||||
func (pm *ProxyManager) Shutdown() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// shutdown process in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
|
||||
@@ -254,3 +254,53 @@ func TestProxyManager_ProfileNonMember(t *testing.T) {
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
model1Config.Proxy = "http://localhost:10001/"
|
||||
|
||||
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
|
||||
model2Config.Proxy = "http://localhost:10002/"
|
||||
|
||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||
model3Config.Proxy = "http://localhost:10003/"
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2", "model3"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": model3Config,
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Start all the processes
|
||||
var wg sync.WaitGroup
|
||||
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
|
||||
wg.Add(1)
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// send a request to trigger the proxy to load
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||
//fmt.Println(w.Code, w.Body.String())
|
||||
}(modelName)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-time.After(time.Second)
|
||||
proxy.Shutdown()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user