Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 21d7973d11 | |||
| cc450e9c5f | |||
| 27465fe053 | |||
| 9667989727 | |||
| d9a1ddea0d | |||
| e7ab024ca0 | |||
| 448ccae959 | |||
| ec0348e431 | |||
| 06eda7f591 |
@@ -13,11 +13,11 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-stale: 14
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -15,7 +15,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
#platform: [intel, cuda, vulkan, cpu, musa]
|
||||
platform: [cuda, vulkan, cpu, musa]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
@@ -26,7 +26,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Docker and Podman support
|
||||
@@ -36,7 +36,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## config.yaml
|
||||
|
||||
@@ -120,16 +120,58 @@ models:
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# profiles eliminates swapping by running multiple models at the same time
|
||||
# Groups provide advanced controls over model swapping behaviour. Using groups
|
||||
# some models can be kept loaded indefinitely, while others are swapped out.
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
#
|
||||
# - models must be defined above in the Models section
|
||||
# - a model can only be a member of one group
|
||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||
groups:
|
||||
# group1 is the default behaviour of llama-swap where only one model is allowed
|
||||
# to run a time across the whole llama-swap instance
|
||||
"group1":
|
||||
# swap controls the model swapping behaviour in within the group
|
||||
# - true : only one model is allowed to run at a time
|
||||
# - false: all models can run together, no swapping
|
||||
swap: true
|
||||
|
||||
# exclusive controls how the group affects other groups
|
||||
# - true: causes all other groups to unload their models when this group runs a model
|
||||
# - false: does not affect other groups
|
||||
exclusive: true
|
||||
|
||||
# members references the models defined above
|
||||
members:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
|
||||
# models in this group are never unloaded
|
||||
"group2":
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
# (not defined above, here for example)
|
||||
- "modelA"
|
||||
- "modelB"
|
||||
|
||||
"forever":
|
||||
# setting persistent to true causes the group to never be affected by the swapping behaviour of
|
||||
# other groups. It is a shortcut to keeping some models always loaded.
|
||||
persistent: true
|
||||
|
||||
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
```
|
||||
|
||||
### Use Case Examples
|
||||
|
||||
@@ -34,6 +34,10 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(config.Profiles) > 0 {
|
||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
|
||||
@@ -33,14 +33,17 @@ func main() {
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.String(200, *responseMessage)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
@@ -63,8 +66,11 @@ func main() {
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
// issue #41
|
||||
@@ -104,6 +110,10 @@ func main() {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||
"model": model,
|
||||
|
||||
// expose some header values for testing
|
||||
"h_content_type": c.GetHeader("Content-Type"),
|
||||
"h_content_length": c.GetHeader("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
+99
-5
@@ -3,12 +3,15 @@ package proxy
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/google/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
@@ -24,12 +27,38 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
@@ -53,16 +82,16 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
var config Config
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
@@ -77,7 +106,72 @@ func LoadConfig(path string) (*Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName, _ := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
|
||||
+96
-1
@@ -35,11 +35,31 @@ models:
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
@@ -52,7 +72,7 @@ profiles:
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
expected := &Config{
|
||||
expected := Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -68,6 +88,17 @@ profiles:
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
@@ -77,6 +108,25 @@ profiles:
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -87,6 +137,51 @@ profiles:
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
|
||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
group2:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
_, err = LoadConfig(tempFile)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
}
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
@@ -26,6 +27,17 @@ func TestMain(m *testing.M) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "debug":
|
||||
testLogger.SetLogLevel(LevelDebug)
|
||||
case "warn":
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
case "info":
|
||||
testLogger.SetLogLevel(LevelInfo)
|
||||
default:
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
|
||||
@@ -170,6 +170,7 @@
|
||||
|
||||
this.eventSource.onmessage = (event) => {
|
||||
this.logData += event.data;
|
||||
this.logData = this.logData.slice(-1024 * 100);
|
||||
this.render();
|
||||
};
|
||||
|
||||
|
||||
+36
-26
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -93,17 +94,17 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != expectedState {
|
||||
p.proxyLogger.Warnf("swapState() Unexpected current state %s, expected %s", p.state, expectedState)
|
||||
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
|
||||
return p.state, ErrExpectedStateMismatch
|
||||
}
|
||||
|
||||
if !isValidTransition(p.state, newState) {
|
||||
p.proxyLogger.Warnf("swapState() Invalid state transition from %s to %s", p.state, newState)
|
||||
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
|
||||
return p.state, ErrInvalidStateTransition
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
p.proxyLogger.Debugf("swapState() State transitioned from %s to %s", expectedState, newState)
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
@@ -187,7 +188,7 @@ func (p *Process) start() error {
|
||||
// Capture the exit error for later signaling
|
||||
go func() {
|
||||
exitErr := p.cmd.Wait()
|
||||
p.proxyLogger.Debugf("cmd.Wait() returned for [%s] error: %v", p.ID, exitErr)
|
||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||
p.cmdWaitChan <- exitErr
|
||||
}()
|
||||
|
||||
@@ -236,32 +237,32 @@ func (p *Process) start() error {
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
case exitErr := <-p.cmdWaitChan:
|
||||
if exitErr != nil {
|
||||
p.proxyLogger.Warnf("upstream command exited prematurely with error: %v", exitErr)
|
||||
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
|
||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
|
||||
} else {
|
||||
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
|
||||
}
|
||||
} else {
|
||||
p.proxyLogger.Warnf("upstream command exited prematurely with no error")
|
||||
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
|
||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||
return fmt.Errorf("upstream command exited prematurely with no error AND state swap failed: %v, current state: %v", err, curState)
|
||||
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
|
||||
} else {
|
||||
return fmt.Errorf("upstream command exited prematurely with no error")
|
||||
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||
}
|
||||
}
|
||||
default:
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
p.proxyLogger.Infof("Health check passed on %s", healthURL)
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||
cancelHealthCheck()
|
||||
break loop
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
endTime, _ := checkDeadline.Deadline()
|
||||
ttl := time.Until(endTime)
|
||||
p.proxyLogger.Infof("Connection refused on %s, giving up in %.0fs", healthURL, ttl.Seconds())
|
||||
p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds())
|
||||
} else {
|
||||
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
|
||||
p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -285,7 +286,7 @@ func (p *Process) start() error {
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
@@ -301,13 +302,17 @@ func (p *Process) start() error {
|
||||
}
|
||||
|
||||
func (p *Process) Stop() {
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for any inflight requests before proceeding
|
||||
p.inFlightRequests.Wait()
|
||||
p.proxyLogger.Debugf("Stopping process [%s]", p.ID)
|
||||
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
||||
|
||||
// calling Stop() when state is invalid is a no-op
|
||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
|
||||
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -315,7 +320,7 @@ func (p *Process) Stop() {
|
||||
p.stopCommand(5 * time.Second)
|
||||
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
|
||||
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,24 +338,24 @@ func (p *Process) Shutdown() {
|
||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("Process [%s] stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
}()
|
||||
|
||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
||||
defer cancelTimeout()
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
|
||||
p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.terminateProcess(); err != nil {
|
||||
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
|
||||
p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
|
||||
p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
case err := <-p.cmdWaitChan:
|
||||
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
|
||||
@@ -359,24 +364,23 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||
// succeeded but that's not a case llama-swap is handling for now.
|
||||
if err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok {
|
||||
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
|
||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
|
||||
p.proxyLogger.Infof("<%s> Process stopped OK", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
|
||||
p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
} else {
|
||||
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
|
||||
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}
|
||||
@@ -436,6 +440,12 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
|
||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
||||
if err == nil {
|
||||
req.ContentLength = contentLength
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
@@ -471,6 +481,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
p.proxyLogger.Debugf("Process [%s] request %s - start: %v, total: %v",
|
||||
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||
p.ID, r.RequestURI, startDuration, totalTime)
|
||||
}
|
||||
|
||||
@@ -337,6 +337,6 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||
process.healthCheckLoopInterval = time.Second // make it faster
|
||||
err := process.start()
|
||||
assert.Equal(t, "upstream command exited prematurely with no error", err.Error())
|
||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||
assert.Equal(t, process.CurrentState(), StateFailed)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
persistent bool
|
||||
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
|
||||
// map of current processes
|
||||
processes map[string]*Process
|
||||
lastUsedProcess string
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
}
|
||||
|
||||
pg := &ProcessGroup{
|
||||
id: id,
|
||||
config: config,
|
||||
swap: groupConfig.Swap,
|
||||
exclusive: groupConfig.Exclusive,
|
||||
persistent: groupConfig.Persistent,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
processes: make(map[string]*Process),
|
||||
}
|
||||
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
return pg
|
||||
}
|
||||
|
||||
// ProxyRequest proxies a request to the specified model
|
||||
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||
if !pg.HasMember(modelID) {
|
||||
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||
}
|
||||
|
||||
if pg.swap {
|
||||
pg.Lock()
|
||||
if pg.lastUsedProcess != modelID {
|
||||
if pg.lastUsedProcess != "" {
|
||||
pg.processes[pg.lastUsedProcess].Stop()
|
||||
}
|
||||
pg.lastUsedProcess = modelID
|
||||
}
|
||||
pg.Unlock()
|
||||
}
|
||||
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses() {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
pg.stopProcesses()
|
||||
}
|
||||
|
||||
// stopProcesses stops all processes in the group
|
||||
func (pg *ProcessGroup) stopProcesses() {
|
||||
if len(pg.processes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Stop()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: false,
|
||||
Exclusive: true,
|
||||
Members: []string{"model3", "model4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_HasMember(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model1"))
|
||||
assert.True(t, pg.HasMember("model2"))
|
||||
assert.False(t, pg.HasMember("model3"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses()
|
||||
|
||||
tests := []string{"model1", "model2"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
// make sure only one process is in the running state
|
||||
count := 0
|
||||
for _, process := range pg.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses()
|
||||
|
||||
tests := []string{"model3", "model4"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure all the processes are running
|
||||
for _, process := range pg.processes {
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
}
|
||||
+108
-156
@@ -26,17 +26,18 @@ const (
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config *Config
|
||||
currentProcesses map[string]*Process
|
||||
ginEngine *gin.Engine
|
||||
config Config
|
||||
ginEngine *gin.Engine
|
||||
|
||||
// logging
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
muxLogger *LogMonitor
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
func New(config Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
@@ -65,13 +66,20 @@ func New(config *Config) *ProxyManager {
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
currentProcesses: make(map[string]*Process),
|
||||
ginEngine: gin.New(),
|
||||
config: config,
|
||||
ginEngine: gin.New(),
|
||||
|
||||
proxyLogger: proxyLogger,
|
||||
muxLogger: stdoutLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
for groupID := range config.Groups {
|
||||
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
||||
pm.processGroups[groupID] = processGroup
|
||||
}
|
||||
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
@@ -200,27 +208,17 @@ func (pm *ProxyManager) StopProcesses() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
pm.stopProcesses()
|
||||
}
|
||||
|
||||
// for internal usage
|
||||
func (pm *ProxyManager) stopProcesses() {
|
||||
if len(pm.currentProcesses) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
for _, processGroup := range pm.processGroups {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
go func(processGroup *ProcessGroup) {
|
||||
defer wg.Done()
|
||||
process.Stop()
|
||||
}(process)
|
||||
processGroup.stopProcesses()
|
||||
}(processGroup)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
pm.currentProcesses = make(map[string]*Process)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Shutdown is called to shutdown all upstream processes
|
||||
@@ -229,18 +227,44 @@ func (pm *ProxyManager) Shutdown() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// shutdown process in parallel
|
||||
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
// Send shutdown signal to all process in groups
|
||||
for _, processGroup := range pm.processGroups {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
go func(processGroup *ProcessGroup) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
processGroup.Shutdown()
|
||||
}(processGroup)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||
}
|
||||
|
||||
if processGroup.exclusive {
|
||||
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
||||
for groupId, otherGroup := range pm.processGroups {
|
||||
if groupId != processGroup.id && !otherGroup.persistent {
|
||||
otherGroup.StopProcesses()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return processGroup, realModelName, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
@@ -270,79 +294,6 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
|
||||
if profileName != "" {
|
||||
if _, found := pm.config.Profiles[profileName]; !found {
|
||||
return nil, fmt.Errorf("model group not found %s", profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(modelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
// check if model is part of the profile
|
||||
if profileName != "" {
|
||||
found := false
|
||||
for _, item := range pm.config.Profiles[profileName] {
|
||||
if item == realModelName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// exit early when already running, otherwise stop everything and swap
|
||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||
|
||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// stop all running models
|
||||
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
|
||||
pm.stopProcesses()
|
||||
if profileName == "" {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
} else {
|
||||
for _, modelName := range pm.config.Profiles[profileName] {
|
||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// requestedProcessKey should exist due to swap
|
||||
return pm.currentProcesses[requestedProcessKey], nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
requestedModel := c.Param("model_id")
|
||||
|
||||
@@ -351,13 +302,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
} else {
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
||||
@@ -395,31 +348,23 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
if process.config.UseModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
if profileName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
@@ -428,16 +373,14 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Parse multipart form
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
@@ -451,15 +394,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Swap to the requested model
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get profile name and model name from the requested model
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Copy all form values
|
||||
for key, values := range c.Request.MultipartForm.Value {
|
||||
@@ -467,10 +411,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
fieldValue := value
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
if process.config.UseModelName != "" {
|
||||
fieldValue = process.config.UseModelName
|
||||
} else if profileName != "" {
|
||||
fieldValue = modelName
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
} else {
|
||||
fieldValue = requestedModel
|
||||
}
|
||||
}
|
||||
field, err := multipartWriter.CreateFormField(key)
|
||||
@@ -531,8 +478,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
modifiedReq.Header = c.Request.Header.Clone()
|
||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||
|
||||
// set the content length of the body
|
||||
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
process.ProxyRequest(c.Writer, modifiedReq)
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||
@@ -554,14 +509,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
context.Header("Content-Type", "application/json")
|
||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||
|
||||
for _, process := range pm.currentProcesses {
|
||||
|
||||
// Append the process ID and State (multiple entries if profiles are being used).
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
})
|
||||
|
||||
for _, processGroup := range pm.processGroups {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Put the results under the `running` key.
|
||||
@@ -572,15 +528,11 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
context.JSON(http.StatusOK, response) // Always return 200 OK
|
||||
}
|
||||
|
||||
func ProcessKeyName(groupName, modelName string) string {
|
||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||
}
|
||||
|
||||
func splitRequestedModel(requestedModel string) (string, string) {
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
||||
for _, group := range pm.processGroups {
|
||||
if group.HasMember(modelName) {
|
||||
return group
|
||||
}
|
||||
}
|
||||
return profileName, modelName
|
||||
return nil
|
||||
}
|
||||
|
||||
+213
-313
@@ -8,6 +8,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -16,14 +17,14 @@ import (
|
||||
)
|
||||
|
||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
@@ -36,59 +37,91 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||
|
||||
}
|
||||
|
||||
// make sure there's only one loaded model
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
}
|
||||
|
||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileModel1 := ProcessKeyName("test", model1)
|
||||
profileModel2 := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1, model2},
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
Groups: map[string]GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model1"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
for modelID, requestedModel := range map[string]string{
|
||||
"model1": profileModel1,
|
||||
"model2": profileModel2,
|
||||
} {
|
||||
tests := []string{"model1", "model2"}
|
||||
for _, requestedModel := range tests {
|
||||
t.Run(requestedModel, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), requestedModel)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
// make sure there's two loaded models
|
||||
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
||||
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// Test that a persistent group is not affected by the swapping behaviour of
|
||||
// other groups.
|
||||
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
// the forever group is persistent and should not be affected by model1
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
// make requests to load all models, loading model1 should not affect model2
|
||||
tests := []string{"model2", "model1"}
|
||||
for _, requestedModel := range tests {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelID)
|
||||
assert.Contains(t, w.Body.String(), requestedModel)
|
||||
}
|
||||
|
||||
// make sure there's two loaded models
|
||||
assert.Len(t, proxy.currentProcesses, 2)
|
||||
_, exists := proxy.currentProcesses[profileModel1]
|
||||
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
||||
|
||||
_, exists = proxy.currentProcesses[profileModel2]
|
||||
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
||||
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
||||
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// When a request for a different model comes in ProxyManager should wait until
|
||||
@@ -98,7 +131,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
@@ -106,7 +139,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
@@ -133,7 +166,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
|
||||
mu.Lock()
|
||||
|
||||
results[key] = w.Body.String()
|
||||
var response map[string]string
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
results[key] = response["responseMessage"]
|
||||
mu.Unlock()
|
||||
}(key)
|
||||
|
||||
@@ -149,7 +184,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
config := Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
@@ -217,51 +252,6 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
func TestProxyManager_ProfileNonMember(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileMemberName := ProcessKeyName("test", model1)
|
||||
profileNonMemberName := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1},
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
// actual member of profile
|
||||
{
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "model1")
|
||||
}
|
||||
|
||||
// actual model, but non-member will 404
|
||||
{
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
@@ -273,24 +263,27 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||
model3Config.Proxy = "http://localhost:10003/"
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2", "model3"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": model3Config,
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
Groups: map[string]GroupConfig{
|
||||
"test": {
|
||||
Swap: false,
|
||||
Members: []string{"model1", "model2", "model3"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Start all the processes
|
||||
var wg sync.WaitGroup
|
||||
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
|
||||
for _, modelName := range []string{"model1", "model2", "model3"} {
|
||||
wg.Add(1)
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
@@ -298,11 +291,10 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// send a request to trigger the proxy to load
|
||||
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||
//fmt.Println(w.Code, w.Body.String())
|
||||
}(modelName)
|
||||
}
|
||||
|
||||
@@ -314,67 +306,44 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Unload(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
proc, err := proxy.swapModel("model1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proc)
|
||||
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
req := httptest.NewRequest("GET", "/unload", nil)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||
req = httptest.NewRequest("GET", "/unload", nil)
|
||||
w = httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
assert.Len(t, proxy.currentProcesses, 0)
|
||||
}
|
||||
|
||||
// issue 62, strip profile slug from model name
|
||||
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
||||
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "ok")
|
||||
// give it a bit of time to stop
|
||||
<-time.After(time.Millisecond * 250)
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Shared configuration
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
LogLevel: "debug",
|
||||
})
|
||||
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
@@ -429,238 +398,127 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
})
|
||||
|
||||
t.Run("multiple models via profile", func(t *testing.T) {
|
||||
// Load more than one model.
|
||||
for _, model := range []string{"model1", "model2"} {
|
||||
profileModel := ProcessKeyName("test", model)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Simulate the browser call.
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
|
||||
// The JSON response must be valid.
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// The response should contain 2 models.
|
||||
assert.Len(t, response.Running, 2)
|
||||
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
}
|
||||
|
||||
// Iterate through the models and check their states as well.
|
||||
for _, entry := range response.Running {
|
||||
_, exists := expectedModels[entry.Model]
|
||||
assert.True(t, exists, "unexpected model %s", entry.Model)
|
||||
assert.Equal(t, "ready", entry.State)
|
||||
delete(expectedModels, entry.Model)
|
||||
}
|
||||
|
||||
// Since we deleted each model while testing for its validity we should have no more models in the response.
|
||||
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
modelInput string
|
||||
expectModel string
|
||||
}{
|
||||
{
|
||||
name: "With Profile Prefix",
|
||||
modelInput: "test:TheExpectedModel",
|
||||
expectModel: "TheExpectedModel", // Profile prefix should be stripped
|
||||
},
|
||||
{
|
||||
name: "Without Profile Prefix",
|
||||
modelInput: "TheExpectedModel",
|
||||
expectModel: "TheExpectedModel", // Should remain the same
|
||||
},
|
||||
}
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("TheExpectedModel"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(tc.modelInput))
|
||||
assert.NoError(t, err)
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
// Generate random content length between 10 and 20
|
||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||
content := make([]byte, contentLength)
|
||||
_, err = fw.Write(content)
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
// Generate random content length between 10 and 20
|
||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||
content := make([]byte, contentLength)
|
||||
_, err = fw.Write(content)
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectModel, response["model"])
|
||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_SplitRequestedModel(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
expectedProfile string
|
||||
expectedModel string
|
||||
}{
|
||||
{"no profile", "gpt-4", "", "gpt-4"},
|
||||
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
|
||||
{"only profile", "profile1:", "profile1", ""},
|
||||
{"empty model", ":gpt-4", "", "gpt-4"},
|
||||
{"empty profile", ":", "", ""},
|
||||
{"no split char", "gpt-4", "", "gpt-4"},
|
||||
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
profileName, modelName := splitRequestedModel(tt.requestedModel)
|
||||
if profileName != tt.expectedProfile {
|
||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||
}
|
||||
if modelName != tt.expectedModel {
|
||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "TheExpectedModel", response["model"])
|
||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
|
||||
}
|
||||
|
||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||
func TestProxyManager_UseModelName(t *testing.T) {
|
||||
|
||||
upstreamModelName := "upstreamModel"
|
||||
|
||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||
modelConfig.UseModelName = upstreamModelName
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1"},
|
||||
},
|
||||
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
|
||||
LogLevel: "error",
|
||||
}
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
requestedModel string
|
||||
}{
|
||||
{"useModelName over rides requested model", "model1"},
|
||||
{"useModelName over rides requested profile:model", "test:model1"},
|
||||
}
|
||||
requestedModel := "model1"
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(requestedModel))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(tt.requestedModel))
|
||||
assert.NoError(t, err)
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, upstreamModelName, response["model"])
|
||||
})
|
||||
}
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, upstreamModelName, response["model"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -720,3 +578,45 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Upstream(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]string
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
assert.Equal(t, "81", response["h_content_length"])
|
||||
assert.Equal(t, "model1", response["responseMessage"])
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user