Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4c3aa40564 | |||
| 84e2c07a7e | |||
| 680af28bcc | |||
| d94db42ffe | |||
| 93cd83c55c | |||
| 5565fca3ac | |||
| d625ab8d92 | |||
| a3f82c140b | |||
| 5c97299e7b |
+13
-1
@@ -15,4 +15,16 @@ builds:
|
|||||||
- goos: freebsd
|
- goos: freebsd
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
|
|
||||||
|
# use zip format for windows
|
||||||
|
archives:
|
||||||
|
- id: default
|
||||||
|
format: tar.gz
|
||||||
|
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
|
builds_info:
|
||||||
|
group: root
|
||||||
|
owner: root
|
||||||
|
format_overrides:
|
||||||
|
- goos: windows
|
||||||
|
format: zip
|
||||||
@@ -1,4 +1,9 @@
|
|||||||

|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
@@ -70,7 +75,14 @@ logRequests: true
|
|||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
# multiline for readability
|
||||||
|
cmd: >
|
||||||
|
llama-server --port 8999
|
||||||
|
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||||
|
|
||||||
|
# environment variables to pass to the command
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
# where to reach the server started by cmd, make sure the ports match
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
@@ -91,16 +103,9 @@ models:
|
|||||||
# default: 0 = never unload model
|
# default: 0 = never unload model
|
||||||
ttl: 60
|
ttl: 60
|
||||||
|
|
||||||
"qwen":
|
# `useModelName` overrides the model name in the request
|
||||||
# environment variables to pass to the command
|
# and sends a specific name to the upstream server
|
||||||
env:
|
useModelName: "qwen:qwq"
|
||||||
- "CUDA_VISIBLE_DEVICES=0"
|
|
||||||
|
|
||||||
# multiline for readability
|
|
||||||
cmd: >
|
|
||||||
llama-server --port 8999
|
|
||||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
|
||||||
proxy: http://127.0.0.1:8999
|
|
||||||
|
|
||||||
# unlisted models do not show up in /v1/models or /upstream lists
|
# unlisted models do not show up in /v1/models or /upstream lists
|
||||||
# but they can still be requested as normal
|
# but they can still be requested as normal
|
||||||
@@ -117,7 +122,7 @@ models:
|
|||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
# profiles eliminates swapping by running multiple models at the same time
|
||||||
#
|
#
|
||||||
# Tips:
|
# Tips:
|
||||||
# - each model must be listening on a unique address and port
|
# - each model must be listening on a unique address and port
|
||||||
@@ -125,8 +130,8 @@ models:
|
|||||||
# - the profile will load and unload all models in the profile at the same time
|
# - the profile will load and unload all models in the profile at the same time
|
||||||
profiles:
|
profiles:
|
||||||
coding:
|
coding:
|
||||||
- "qwen"
|
|
||||||
- "llama"
|
- "llama"
|
||||||
|
- "qwen-unlisted"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use Case Examples
|
### Use Case Examples
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type ModelConfig struct {
|
|||||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||||
UnloadAfter int `yaml:"ttl"`
|
UnloadAfter int `yaml:"ttl"`
|
||||||
Unlisted bool `yaml:"unlisted"`
|
Unlisted bool `yaml:"unlisted"`
|
||||||
|
UseModelName string `yaml:"useModelName"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
|
|||||||
+100
-112
@@ -30,11 +30,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
logMonitor *LogMonitor
|
logMonitor *LogMonitor
|
||||||
healthCheckTimeout int
|
|
||||||
|
healthCheckTimeout int
|
||||||
|
healthCheckLoopInterval time.Duration
|
||||||
|
|
||||||
lastRequestHandled time.Time
|
lastRequestHandled time.Time
|
||||||
|
|
||||||
@@ -54,51 +56,57 @@ type Process struct {
|
|||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
logMonitor: logMonitor,
|
logMonitor: logMonitor,
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
state: StateStopped,
|
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||||
shutdownCtx: ctx,
|
state: StateStopped,
|
||||||
shutdownCancel: cancel,
|
shutdownCtx: ctx,
|
||||||
|
shutdownCancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) setState(newState ProcessState) error {
|
// custom error types for swapping state
|
||||||
// enforce valid state transitions
|
var (
|
||||||
invalidTransition := false
|
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||||
if p.state == StateStopped {
|
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||||
// stopped -> starting
|
)
|
||||||
if newState != StateStarting {
|
|
||||||
invalidTransition = true
|
// swapState performs a compare and swap of the state atomically. It returns the current state
|
||||||
}
|
// and an error if the swap failed.
|
||||||
} else if p.state == StateStarting {
|
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
||||||
// starting -> ready | failed | stopping
|
p.stateMutex.Lock()
|
||||||
if newState != StateReady && newState != StateFailed && newState != StateStopping {
|
defer p.stateMutex.Unlock()
|
||||||
invalidTransition = true
|
|
||||||
}
|
if p.state != expectedState {
|
||||||
} else if p.state == StateReady {
|
return p.state, ErrExpectedStateMismatch
|
||||||
// ready -> stopping
|
|
||||||
if newState != StateStopping {
|
|
||||||
invalidTransition = true
|
|
||||||
}
|
|
||||||
} else if p.state == StateStopping {
|
|
||||||
// stopping -> stopped | shutdown
|
|
||||||
if newState != StateStopped && newState != StateShutdown {
|
|
||||||
invalidTransition = true
|
|
||||||
}
|
|
||||||
} else if p.state == StateFailed || p.state == StateShutdown {
|
|
||||||
invalidTransition = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if invalidTransition {
|
if !isValidTransition(p.state, newState) {
|
||||||
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
|
return p.state, ErrInvalidStateTransition
|
||||||
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
return nil
|
return p.state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to encapsulate transition rules
|
||||||
|
func isValidTransition(from, to ProcessState) bool {
|
||||||
|
switch from {
|
||||||
|
case StateStopped:
|
||||||
|
return to == StateStarting
|
||||||
|
case StateStarting:
|
||||||
|
return to == StateReady || to == StateFailed || to == StateStopping
|
||||||
|
case StateReady:
|
||||||
|
return to == StateStopping
|
||||||
|
case StateStopping:
|
||||||
|
return to == StateStopped || to == StateShutdown
|
||||||
|
case StateFailed, StateShutdown:
|
||||||
|
return false // No transitions allowed from these states
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) CurrentState() ProcessState {
|
func (p *Process) CurrentState() ProcessState {
|
||||||
@@ -116,56 +124,33 @@ func (p *Process) start() error {
|
|||||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
// multiple start() calls will wait for the one that is actually starting to
|
|
||||||
// complete before proceeding.
|
|
||||||
// ===========
|
|
||||||
curState := p.CurrentState()
|
|
||||||
|
|
||||||
if curState == StateReady {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if curState == StateStarting {
|
|
||||||
p.waitStarting.Wait()
|
|
||||||
|
|
||||||
if state := p.CurrentState(); state != StateReady {
|
|
||||||
return fmt.Errorf("start() failed current state: %v", state)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// ===========
|
|
||||||
|
|
||||||
// There is the possibility of a hard to replicate race condition where
|
|
||||||
// curState *WAS* StateStopped but by the time we get to the p.stateMutex.Lock()
|
|
||||||
// below, it's value has changed!
|
|
||||||
|
|
||||||
p.stateMutex.Lock()
|
|
||||||
defer p.stateMutex.Unlock()
|
|
||||||
|
|
||||||
// with the exclusive lock, check if p.state is StateStopped, which is the only valid state
|
|
||||||
// to transition from to StateReady
|
|
||||||
|
|
||||||
if p.state != StateStopped {
|
|
||||||
if p.state == StateReady {
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("start() can not proceed expected StateReady but process is in %v", p.state)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := p.setState(StateStarting); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.waitStarting.Add(1)
|
|
||||||
defer p.waitStarting.Done()
|
|
||||||
|
|
||||||
args, err := p.config.SanitizedCommand()
|
args, err := p.config.SanitizedCommand()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||||
|
if err == ErrExpectedStateMismatch {
|
||||||
|
// already starting, just wait for it to complete and expect
|
||||||
|
// it to be be in the Ready start after. If not, return an error
|
||||||
|
if curState == StateStarting {
|
||||||
|
p.waitStarting.Wait()
|
||||||
|
if state := p.CurrentState(); state == StateReady {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("processes was in state %v when start() was called", curState)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.waitStarting.Add(1)
|
||||||
|
defer p.waitStarting.Done()
|
||||||
|
|
||||||
p.cmd = exec.Command(args[0], args[1:]...)
|
p.cmd = exec.Command(args[0], args[1:]...)
|
||||||
p.cmd.Stdout = p.logMonitor
|
p.cmd.Stdout = p.logMonitor
|
||||||
p.cmd.Stderr = p.logMonitor
|
p.cmd.Stderr = p.logMonitor
|
||||||
@@ -173,8 +158,14 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
|
|
||||||
|
// Set process state to failed
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.setState(StateFailed)
|
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||||
|
err, curState, swapErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
return fmt.Errorf("start() failed: %v", err)
|
return fmt.Errorf("start() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,13 +200,16 @@ func (p *Process) start() error {
|
|||||||
)
|
)
|
||||||
defer cancelHealthCheck()
|
defer cancelHealthCheck()
|
||||||
|
|
||||||
// Health check loop
|
|
||||||
loop:
|
loop:
|
||||||
|
// Ready Check loop
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-checkDeadline.Done():
|
case <-checkDeadline.Done():
|
||||||
p.setState(StateFailed)
|
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||||
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
|
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||||
|
}
|
||||||
case <-p.shutdownCtx.Done():
|
case <-p.shutdownCtx.Done():
|
||||||
return errors.New("health check interrupted due to shutdown")
|
return errors.New("health check interrupted due to shutdown")
|
||||||
default:
|
default:
|
||||||
@@ -233,7 +227,7 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-time.After(5 * time.Second)
|
<-time.After(p.healthCheckLoopInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -244,7 +238,7 @@ func (p *Process) start() error {
|
|||||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||||
|
|
||||||
for range time.Tick(time.Second) {
|
for range time.Tick(time.Second) {
|
||||||
if p.state != StateReady {
|
if p.CurrentState() != StateReady {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,26 +254,28 @@ func (p *Process) start() error {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.setState(StateReady)
|
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||||
|
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
// wait for any inflight requests before proceeding
|
// wait for any inflight requests before proceeding
|
||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
p.stateMutex.Lock()
|
|
||||||
defer p.stateMutex.Unlock()
|
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
// calling Stop() when state is invalid is a no-op
|
||||||
if err := p.setState(StateStopping); err != nil {
|
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
|
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the process with a graceful exit timeout
|
// stop the process with a graceful exit timeout
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(5 * time.Second)
|
||||||
|
|
||||||
if err := p.setState(StateStopped); err != nil {
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
|
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,19 +283,9 @@ func (p *Process) Stop() {
|
|||||||
// of time for any inflight requests to complete before shutting down. If the Process
|
// of time for any inflight requests to complete before shutting down. If the Process
|
||||||
// is in the state of starting, it will cancel it and shut it down
|
// is in the state of starting, it will cancel it and shut it down
|
||||||
func (p *Process) Shutdown() {
|
func (p *Process) Shutdown() {
|
||||||
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
|
|
||||||
p.shutdownCancel()
|
p.shutdownCancel()
|
||||||
|
|
||||||
p.stateMutex.Lock()
|
|
||||||
defer p.stateMutex.Unlock()
|
|
||||||
p.setState(StateStopping)
|
|
||||||
|
|
||||||
// 5 seconds to stop the process
|
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(5 * time.Second)
|
||||||
if err := p.setState(StateShutdown); err != nil {
|
p.state = StateShutdown
|
||||||
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
|
|
||||||
}
|
|
||||||
p.setState(StateShutdown)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||||
@@ -318,7 +304,9 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
if err := p.terminateProcess(); err != nil {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! failed to gracefully terminate process [%s]: %v\n", p.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func (p *Process) terminateProcess() error {
|
||||||
|
return p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
|
}
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Process) terminateProcess() error {
|
||||||
|
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
||||||
|
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
+26
-21
@@ -225,30 +225,32 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetState(t *testing.T) {
|
func TestProcess_SwapState(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
currentState ProcessState
|
currentState ProcessState
|
||||||
|
expectedState ProcessState
|
||||||
newState ProcessState
|
newState ProcessState
|
||||||
expectedError error
|
expectedError error
|
||||||
expectedResult ProcessState
|
expectedResult ProcessState
|
||||||
}{
|
}{
|
||||||
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
|
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||||
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
|
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||||
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
|
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
||||||
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
|
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||||
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
|
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||||
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
|
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||||
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
|
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||||
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
|
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||||
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
|
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
||||||
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
|
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||||
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
|
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
||||||
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
|
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||||
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
|
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
||||||
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
|
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
||||||
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
|
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||||
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
|
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||||
|
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -257,7 +259,7 @@ func TestSetState(t *testing.T) {
|
|||||||
state: test.currentState,
|
state: test.currentState,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := p.setState(test.newState)
|
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||||
if err != nil && test.expectedError == nil {
|
if err != nil && test.expectedError == nil {
|
||||||
t.Errorf("Unexpected error: %v", err)
|
t.Errorf("Unexpected error: %v", err)
|
||||||
} else if err == nil && test.expectedError != nil {
|
} else if err == nil && test.expectedError != nil {
|
||||||
@@ -268,8 +270,8 @@ func TestSetState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.state != test.expectedResult {
|
if resultState != test.expectedResult {
|
||||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
|
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -290,11 +292,14 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
|||||||
healthCheckTTLSeconds := 30
|
healthCheckTTLSeconds := 30
|
||||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
||||||
|
|
||||||
|
// make it a lot faster
|
||||||
|
process.healthCheckLoopInterval = time.Second
|
||||||
|
|
||||||
// start a goroutine to simulate a shutdown
|
// start a goroutine to simulate a shutdown
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
<-time.After(time.Second * 2)
|
<-time.After(time.Millisecond * 500)
|
||||||
process.Shutdown()
|
process.Shutdown()
|
||||||
}()
|
}()
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|||||||
+47
-19
@@ -72,14 +72,27 @@ func New(config *Config) *ProxyManager {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
// see: issue: #81, #77 and #42 for CORS issues
|
||||||
// respond with permissive OPTIONS for any endpoint
|
// respond with permissive OPTIONS for any endpoint
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
|
|
||||||
|
// set this for all requests
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
if c.Request.Method == "OPTIONS" {
|
if c.Request.Method == "OPTIONS" {
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||||
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
|
||||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
// allow whatever the client requested by default
|
||||||
c.AbortWithStatus(204)
|
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||||
|
c.Header("Access-Control-Allow-Headers", headers)
|
||||||
|
} else {
|
||||||
|
c.Header(
|
||||||
|
"Access-Control-Allow-Headers",
|
||||||
|
"Content-Type, Authorization, Accept, X-Requested-With",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
c.Header("Access-Control-Max-Age", "86400")
|
||||||
|
c.AbortWithStatus(http.StatusNoContent)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
@@ -96,7 +109,7 @@ func New(config *Config) *ProxyManager {
|
|||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIAudioTranscriptionHandler)
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||||
|
|
||||||
@@ -351,12 +364,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
process, err := pm.swapModel(requestedModel)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
return
|
return
|
||||||
} else {
|
}
|
||||||
|
|
||||||
// strip
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
|
if process.config.UseModelName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
if profileName != "" {
|
if profileName != "" {
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||||
@@ -366,17 +388,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
|
||||||
c.Request.Header.Del("transfer-encoding")
|
|
||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
|
||||||
|
|
||||||
process.ProxyRequest(c.Writer, c.Request)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
|
c.Request.Header.Del("transfer-encoding")
|
||||||
|
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
|
||||||
|
process.ProxyRequest(c.Writer, c.Request)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
// Create a new buffer for the reconstructed request
|
// Create a new buffer for the reconstructed request
|
||||||
var requestBuffer bytes.Buffer
|
var requestBuffer bytes.Buffer
|
||||||
@@ -410,8 +434,12 @@ func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
|
|||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
fieldValue := value
|
fieldValue := value
|
||||||
// If this is the model field and we have a profile, use just the model name
|
// If this is the model field and we have a profile, use just the model name
|
||||||
if key == "model" && profileName != "" {
|
if key == "model" {
|
||||||
fieldValue = modelName
|
if process.config.UseModelName != "" {
|
||||||
|
fieldValue = process.config.UseModelName
|
||||||
|
} else if profileName != "" {
|
||||||
|
fieldValue = modelName
|
||||||
|
}
|
||||||
}
|
}
|
||||||
field, err := multipartWriter.CreateFormField(key)
|
field, err := multipartWriter.CreateFormField(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -532,3 +532,202 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_SplitRequestedModel(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
requestedModel string
|
||||||
|
expectedProfile string
|
||||||
|
expectedModel string
|
||||||
|
}{
|
||||||
|
{"no profile", "gpt-4", "", "gpt-4"},
|
||||||
|
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
|
||||||
|
{"only profile", "profile1:", "profile1", ""},
|
||||||
|
{"empty model", ":gpt-4", "", "gpt-4"},
|
||||||
|
{"empty profile", ":", "", ""},
|
||||||
|
{"no split char", "gpt-4", "", "gpt-4"},
|
||||||
|
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
profileName, modelName := splitRequestedModel(tt.requestedModel)
|
||||||
|
if profileName != tt.expectedProfile {
|
||||||
|
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||||
|
}
|
||||||
|
if modelName != tt.expectedModel {
|
||||||
|
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||||
|
func TestProxyManager_UseModelName(t *testing.T) {
|
||||||
|
|
||||||
|
upstreamModelName := "upstreamModel"
|
||||||
|
|
||||||
|
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||||
|
modelConfig.UseModelName = upstreamModelName
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1"},
|
||||||
|
},
|
||||||
|
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": modelConfig,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
description string
|
||||||
|
requestedModel string
|
||||||
|
}{
|
||||||
|
{"useModelName over rides requested model", "model1"},
|
||||||
|
{"useModelName over rides requested profile:model", "test:model1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
|
||||||
|
// Create a buffer with multipart form data
|
||||||
|
var b bytes.Buffer
|
||||||
|
w := multipart.NewWriter(&b)
|
||||||
|
|
||||||
|
// Add the model field
|
||||||
|
fw, err := w.CreateFormField("model")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = fw.Write([]byte(tt.requestedModel))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Add a file field
|
||||||
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = fw.Write([]byte("test"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
// Create the request with the multipart form data
|
||||||
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(rec, req)
|
||||||
|
|
||||||
|
// Verify the response
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
var response map[string]string
|
||||||
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, upstreamModelName, response["model"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogRequests: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
requestHeaders map[string]string
|
||||||
|
expectedStatus int
|
||||||
|
expectedHeaders map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "OPTIONS with no headers",
|
||||||
|
method: "OPTIONS",
|
||||||
|
expectedStatus: http.StatusNoContent,
|
||||||
|
expectedHeaders: map[string]string{
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||||
|
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OPTIONS with specific headers",
|
||||||
|
method: "OPTIONS",
|
||||||
|
requestHeaders: map[string]string{
|
||||||
|
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusNoContent,
|
||||||
|
expectedHeaders: map[string]string{
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||||
|
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-OPTIONS request",
|
||||||
|
method: "GET",
|
||||||
|
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
||||||
|
for k, v := range tt.requestHeaders {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ginEngine.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||||
|
|
||||||
|
for header, expectedValue := range tt.expectedHeaders {
|
||||||
|
assert.Equal(t, expectedValue, w.Header().Get(header))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_CORSHeadersInRegularRequest(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogRequests: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
// Test that CORS headers are present in regular POST requests
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ginEngine.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user