Compare commits

...

9 Commits

Author SHA1 Message Date
Grigorii Khvatskii 4c3aa40564 add graceful process termination on windows (#82) 2025-03-25 15:26:33 -07:00
Benson Wong 84e2c07a7e Refactor wildcard out of CORS headers (#81)
Changes to CORS functionality: 

- `Access-Control-Allow-Origin: *` is set for all requests 
- for pre-flight OPTIONS requests
  - specify methods: `Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS`
  - if the client sent `Access-Control-Request-Headers` then echo back the same value in `Access-Control-Allow-Headers`. If no `Access-Control-Request-Headers` were sent, then send back a default set
  - set `Access-Control-Max-Age: 86400`, which may improve performance
- Add CORS tests to the proxy-manager
2025-03-25 15:24:43 -07:00
Benson Wong 680af28bcc Allow very permissive CORS headers (#77) 2025-03-20 15:50:21 -07:00
Benson Wong d94db42ffe fix bug checking incorrect error 2025-03-20 15:49:36 -07:00
Benson Wong 93cd83c55c add override for windows (#76) 2025-03-20 13:23:04 -07:00
Benson Wong 5565fca3ac add some badges to README 2025-03-19 11:25:06 -07:00
Benson Wong d625ab8d92 Refactor process state management (#70) (#73)
* add isValidStateTransition helper function
* Replace Process.setState() with Process.swapState()
* Refactor locking logic in Process
2025-03-15 17:14:03 -07:00
Benson Wong a3f82c140b tidy up config examples in README 2025-03-15 10:36:45 -07:00
Benson Wong 5c97299e7b Add support for sending a custom model name to upstream (#69) (#71)
* add test for splitRequestedModel()
* Add `useModelName` parameter to model configuration
* add docs to README
2025-03-14 21:07:52 -07:00
9 changed files with 427 additions and 166 deletions
+13 -1
View File
@@ -15,4 +15,16 @@ builds:
- goos: freebsd - goos: freebsd
goarch: arm64 goarch: arm64
- goos: windows - goos: windows
goarch: arm64 goarch: arm64
# use zip format for windows
archives:
- id: default
format: tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds_info:
group: root
owner: root
format_overrides:
- goos: windows
format: zip
+18 -13
View File
@@ -1,4 +1,9 @@
![llama-swap header image](header.jpeg) ![llama-swap header image](header.jpeg)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap # llama-swap
@@ -70,7 +75,14 @@ logRequests: true
# define valid model values and the upstream server start # define valid model values and the upstream server start
models: models:
"llama": "llama":
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf # multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# where to reach the server started by cmd, make sure the ports match # where to reach the server started by cmd, make sure the ports match
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
@@ -91,16 +103,9 @@ models:
# default: 0 = never unload model # default: 0 = never unload model
ttl: 60 ttl: 60
"qwen": # `useModelName` overrides the model name in the request
# environment variables to pass to the command # and sends a specific name to the upstream server
env: useModelName: "qwen:qwq"
- "CUDA_VISIBLE_DEVICES=0"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
# unlisted models do not show up in /v1/models or /upstream lists # unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal # but they can still be requested as normal
@@ -117,7 +122,7 @@ models:
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles make it easy to manage multi model (and gpu) configurations. # profiles eliminate swapping by running multiple models at the same time
# #
# Tips: # Tips:
# - each model must be listening on a unique address and port # - each model must be listening on a unique address and port
@@ -125,8 +130,8 @@ models:
# - the profile will load and unload all models in the profile at the same time # - the profile will load and unload all models in the profile at the same time
profiles: profiles:
coding: coding:
- "qwen"
- "llama" - "llama"
- "qwen-unlisted"
``` ```
### Use Case Examples ### Use Case Examples
+1
View File
@@ -17,6 +17,7 @@ type ModelConfig struct {
CheckEndpoint string `yaml:"checkEndpoint"` CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"` UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"` Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
} }
func (m *ModelConfig) SanitizedCommand() ([]string, error) { func (m *ModelConfig) SanitizedCommand() ([]string, error) {
+100 -112
View File
@@ -30,11 +30,13 @@ const (
) )
type Process struct { type Process struct {
ID string ID string
config ModelConfig config ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
logMonitor *LogMonitor logMonitor *LogMonitor
healthCheckTimeout int
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time lastRequestHandled time.Time
@@ -54,51 +56,57 @@ type Process struct {
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
logMonitor: logMonitor, logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout, healthCheckTimeout: healthCheckTimeout,
state: StateStopped, healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
shutdownCtx: ctx, state: StateStopped,
shutdownCancel: cancel, shutdownCtx: ctx,
shutdownCancel: cancel,
} }
} }
func (p *Process) setState(newState ProcessState) error { // custom error types for swapping state
// enforce valid state transitions var (
invalidTransition := false ErrExpectedStateMismatch = errors.New("expected state mismatch")
if p.state == StateStopped { ErrInvalidStateTransition = errors.New("invalid state transition")
// stopped -> starting )
if newState != StateStarting {
invalidTransition = true // swapState performs a compare and swap of the state atomically. It returns the current state
} // and an error if the swap failed.
} else if p.state == StateStarting { func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
// starting -> ready | failed | stopping p.stateMutex.Lock()
if newState != StateReady && newState != StateFailed && newState != StateStopping { defer p.stateMutex.Unlock()
invalidTransition = true
} if p.state != expectedState {
} else if p.state == StateReady { return p.state, ErrExpectedStateMismatch
// ready -> stopping
if newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateStopping {
// stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
} }
if invalidTransition { if !isValidTransition(p.state, newState) {
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState)) return p.state, ErrInvalidStateTransition
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
} }
p.state = newState p.state = newState
return nil return p.state, nil
}
// Helper function to encapsulate transition rules
func isValidTransition(from, to ProcessState) bool {
switch from {
case StateStopped:
return to == StateStarting
case StateStarting:
return to == StateReady || to == StateFailed || to == StateStopping
case StateReady:
return to == StateStopping
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateFailed, StateShutdown:
return false // No transitions allowed from these states
}
return false
} }
func (p *Process) CurrentState() ProcessState { func (p *Process) CurrentState() ProcessState {
@@ -116,56 +124,33 @@ func (p *Process) start() error {
return fmt.Errorf("can not start(), upstream proxy missing") return fmt.Errorf("can not start(), upstream proxy missing")
} }
// multiple start() calls will wait for the one that is actually starting to
// complete before proceeding.
// ===========
curState := p.CurrentState()
if curState == StateReady {
return nil
}
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state != StateReady {
return fmt.Errorf("start() failed current state: %v", state)
}
return nil
}
// ===========
// There is the possibility of a hard to replicate race condition where
// curState *WAS* StateStopped but by the time we get to the p.stateMutex.Lock()
// below, it's value has changed!
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// with the exclusive lock, check if p.state is StateStopped, which is the only valid state
// to transition from to StateReady
if p.state != StateStopped {
if p.state == StateReady {
return nil
} else {
return fmt.Errorf("start() can not proceed expected StateReady but process is in %v", p.state)
}
}
if err := p.setState(StateStarting); err != nil {
return err
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand() args, err := p.config.SanitizedCommand()
if err != nil { if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err) return fmt.Errorf("unable to get sanitized command: %v", err)
} }
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
// already starting, just wait for it to complete and expect
// it to be in the Ready state after. If not, return an error
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state == StateReady {
return nil
} else {
return fmt.Errorf("process was already starting but wound up in state %v", state)
}
} else {
return fmt.Errorf("processes was in state %v when start() was called", curState)
}
} else {
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...) p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.logMonitor p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.logMonitor p.cmd.Stderr = p.logMonitor
@@ -173,8 +158,14 @@ func (p *Process) start() error {
err = p.cmd.Start() err = p.cmd.Start()
// Set process state to failed
if err != nil { if err != nil {
p.setState(StateFailed) if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
return fmt.Errorf(
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
err, curState, swapErr,
)
}
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
@@ -209,13 +200,16 @@ func (p *Process) start() error {
) )
defer cancelHealthCheck() defer cancelHealthCheck()
// Health check loop
loop: loop:
// Ready Check loop
for { for {
select { select {
case <-checkDeadline.Done(): case <-checkDeadline.Done():
p.setState(StateFailed) if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds()) return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
} else {
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
case <-p.shutdownCtx.Done(): case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown") return errors.New("health check interrupted due to shutdown")
default: default:
@@ -233,7 +227,7 @@ func (p *Process) start() error {
} }
} }
<-time.After(5 * time.Second) <-time.After(p.healthCheckLoopInterval)
} }
} }
@@ -244,7 +238,7 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) { for range time.Tick(time.Second) {
if p.state != StateReady { if p.CurrentState() != StateReady {
return return
} }
@@ -260,26 +254,28 @@ func (p *Process) start() error {
}() }()
} }
return p.setState(StateReady) if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
} else {
return nil
}
} }
func (p *Process) Stop() { func (p *Process) Stop() {
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// calling Stop() when state is invalid is a no-op // calling Stop() when state is invalid is a no-op
if err := p.setState(StateStopping); err != nil { if curState, err := p.swapState(StateReady, StateStopping); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err) fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
return return
} }
// stop the process with a graceful exit timeout // stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
if err := p.setState(StateStopped); err != nil { if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err)) fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
} }
} }
@@ -287,19 +283,9 @@ func (p *Process) Stop() {
// of time for any inflight requests to complete before shutting down. If the Process // of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down // is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() { func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel() p.shutdownCancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
if err := p.setState(StateShutdown); err != nil { p.state = StateShutdown
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
}
p.setState(StateShutdown)
} }
// stopCommand will send a SIGTERM to the process and wait for it to exit. // stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -318,7 +304,9 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
return return
} }
p.cmd.Process.Signal(syscall.SIGTERM) if err := p.terminateProcess(); err != nil {
fmt.Fprintf(p.logMonitor, "!!! failed to gracefully terminate process [%s]: %v\n", p.ID, err)
}
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
+9
View File
@@ -0,0 +1,9 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
+14
View File
@@ -0,0 +1,14 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+26 -21
View File
@@ -225,30 +225,32 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
} }
} }
func TestSetState(t *testing.T) { func TestProcess_SwapState(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
currentState ProcessState currentState ProcessState
expectedState ProcessState
newState ProcessState newState ProcessState
expectedError error expectedError error
expectedResult ProcessState expectedResult ProcessState
}{ }{
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting}, {"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateReady, nil, StateReady}, {"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed}, {"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping}, {"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping}, {"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped}, {"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown}, {"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped}, {"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting}, {"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady}, {"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady}, {"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping}, {"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed}, {"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed}, {"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown}, {"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown}, {"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
} }
for _, test := range tests { for _, test := range tests {
@@ -257,7 +259,7 @@ func TestSetState(t *testing.T) {
state: test.currentState, state: test.currentState,
} }
err := p.setState(test.newState) resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil { if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil { } else if err == nil && test.expectedError != nil {
@@ -268,8 +270,8 @@ func TestSetState(t *testing.T) {
} }
} }
if p.state != test.expectedResult { if resultState != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state) t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
} }
}) })
} }
@@ -290,11 +292,14 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
healthCheckTTLSeconds := 30 healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor) process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
// start a goroutine to simulate a shutdown // start a goroutine to simulate a shutdown
var wg sync.WaitGroup var wg sync.WaitGroup
go func() { go func() {
defer wg.Done() defer wg.Done()
<-time.After(time.Second * 2) <-time.After(time.Millisecond * 500)
process.Shutdown() process.Shutdown()
}() }()
wg.Add(1) wg.Add(1)
+47 -19
View File
@@ -72,14 +72,27 @@ func New(config *Config) *ProxyManager {
}) })
} }
// see: https://github.com/mostlygeek/llama-swap/issues/42 // see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint // respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
// set this for all requests
c.Header("Access-Control-Allow-Origin", "*")
if c.Request.Method == "OPTIONS" { if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization") // allow whatever the client requested by default
c.AbortWithStatus(204) if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
c.Header("Access-Control-Allow-Headers", headers)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return return
} }
c.Next() c.Next()
@@ -96,7 +109,7 @@ func New(config *Config) *ProxyManager {
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIAudioTranscriptionHandler) pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -351,12 +364,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
} }
if process, err := pm.swapModel(requestedModel); err != nil { process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return return
} else { }
// strip // issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
} else {
profileName, modelName := splitRequestedModel(requestedModel) profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" { if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName) bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
@@ -366,17 +388,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
} }
} }
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
} }
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
} }
func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) { func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed // We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request // Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer var requestBuffer bytes.Buffer
@@ -410,8 +434,12 @@ func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
for _, value := range values { for _, value := range values {
fieldValue := value fieldValue := value
// If this is the model field and we have a profile, use just the model name // If this is the model field and we have a profile, use just the model name
if key == "model" && profileName != "" { if key == "model" {
fieldValue = modelName if process.config.UseModelName != "" {
fieldValue = process.config.UseModelName
} else if profileName != "" {
fieldValue = modelName
}
} }
field, err := multipartWriter.CreateFormField(key) field, err := multipartWriter.CreateFormField(key)
if err != nil { if err != nil {
+199
View File
@@ -532,3 +532,202 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
}) })
} }
} }
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
}
// Test that useModelName in the configuration overrides the model name sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{
"model1": modelConfig,
},
}
proxy := New(config)
defer proxy.StopProcesses()
tests := []struct {
description string
requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
for _, tt := range tests {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogRequests: true,
}
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses()
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
func TestProxyManager_CORSHeadersInRegularRequest(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogRequests: true,
}
proxy := New(config)
defer proxy.StopProcesses()
// Test that CORS headers are present in regular POST requests
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
}