Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 09e52c0500 | |||
| ca9063ffbe | |||
| 21d7973d11 | |||
| cc450e9c5f | |||
| 27465fe053 | |||
| 9667989727 | |||
| d9a1ddea0d | |||
| e7ab024ca0 | |||
| 448ccae959 | |||
| ec0348e431 | |||
| 06eda7f591 |
@@ -13,11 +13,11 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@v9
|
- uses: actions/stale@v9
|
||||||
with:
|
with:
|
||||||
days-before-issue-stale: 30
|
days-before-issue-stale: 14
|
||||||
days-before-issue-close: 14
|
days-before-issue-close: 14
|
||||||
stale-issue-label: "stale"
|
stale-issue-label: "stale"
|
||||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
||||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
||||||
days-before-pr-stale: -1
|
days-before-pr-stale: -1
|
||||||
days-before-pr-close: -1
|
days-before-pr-close: -1
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
platform: [intel, cuda, vulkan, cpu, musa]
|
#platform: [intel, cuda, vulkan, cpu, musa]
|
||||||
|
platform: [cuda, vulkan, cpu, musa]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
|||||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||||
- ✅ Docker and Podman support
|
- ✅ Docker and Podman support
|
||||||
@@ -36,7 +36,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
|||||||
|
|
||||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||||
|
|
||||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
@@ -70,6 +70,14 @@ healthCheckTimeout: 60
|
|||||||
# Valid log levels: debug, info (default), warn, error
|
# Valid log levels: debug, info (default), warn, error
|
||||||
logLevel: info
|
logLevel: info
|
||||||
|
|
||||||
|
# Automatic Port Values
|
||||||
|
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
|
||||||
|
# when you use ${PORT} you can omit a custom model.proxy value, as it will
|
||||||
|
# default to http://localhost:${PORT}
|
||||||
|
|
||||||
|
# override the default port (5800) for automatic port values
|
||||||
|
startPort: 10001
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
@@ -83,6 +91,7 @@ models:
|
|||||||
- "CUDA_VISIBLE_DEVICES=0"
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
# where to reach the server started by cmd, make sure the ports match
|
||||||
|
# can be omitted if you use an automatic ${PORT} in cmd
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# aliases names to use this model for
|
# aliases names to use this model for
|
||||||
@@ -109,27 +118,69 @@ models:
|
|||||||
# but they can still be requested as normal
|
# but they can still be requested as normal
|
||||||
"qwen-unlisted":
|
"qwen-unlisted":
|
||||||
unlisted: true
|
unlisted: true
|
||||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
# Docker Support (v26.1.4+ required!)
|
||||||
"docker-llama":
|
"docker-llama":
|
||||||
proxy: "http://127.0.0.1:9790"
|
proxy: "http://127.0.0.1:${PORT}"
|
||||||
cmd: >
|
cmd: >
|
||||||
docker run --name dockertest
|
docker run --name dockertest
|
||||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
# profiles eliminates swapping by running multiple models at the same time
|
# Groups provide advanced controls over model swapping behaviour. Using groups
|
||||||
|
# some models can be kept loaded indefinitely, while others are swapped out.
|
||||||
#
|
#
|
||||||
# Tips:
|
# Tips:
|
||||||
# - each model must be listening on a unique address and port
|
#
|
||||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
# - models must be defined above in the Models section
|
||||||
# - the profile will load and unload all models in the profile at the same time
|
# - a model can only be a member of one group
|
||||||
profiles:
|
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||||
coding:
|
# - see issue #109 for details
|
||||||
- "llama"
|
#
|
||||||
- "qwen-unlisted"
|
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||||
|
groups:
|
||||||
|
# group1 is the default behaviour of llama-swap where only one model is allowed
|
||||||
|
# to run a time across the whole llama-swap instance
|
||||||
|
"group1":
|
||||||
|
# swap controls the model swapping behaviour in within the group
|
||||||
|
# - true : only one model is allowed to run at a time
|
||||||
|
# - false: all models can run together, no swapping
|
||||||
|
swap: true
|
||||||
|
|
||||||
|
# exclusive controls how the group affects other groups
|
||||||
|
# - true: causes all other groups to unload their models when this group runs a model
|
||||||
|
# - false: does not affect other groups
|
||||||
|
exclusive: true
|
||||||
|
|
||||||
|
# members references the models defined above
|
||||||
|
members:
|
||||||
|
- "llama"
|
||||||
|
- "qwen-unlisted"
|
||||||
|
|
||||||
|
# models in this group are never unloaded
|
||||||
|
"group2":
|
||||||
|
swap: false
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "docker-llama"
|
||||||
|
# (not defined above, here for example)
|
||||||
|
- "modelA"
|
||||||
|
- "modelB"
|
||||||
|
|
||||||
|
"forever":
|
||||||
|
# setting persistent to true causes the group to never be affected by the swapping behaviour of
|
||||||
|
# other groups. It is a shortcut to keeping some models always loaded.
|
||||||
|
persistent: true
|
||||||
|
|
||||||
|
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
|
||||||
|
swap: false
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "forever-modelA"
|
||||||
|
- "forever-modelB"
|
||||||
|
- "forever-modelc"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use Case Examples
|
### Use Case Examples
|
||||||
|
|||||||
@@ -34,6 +34,10 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(config.Profiles) > 0 {
|
||||||
|
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||||
|
}
|
||||||
|
|
||||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||||
gin.SetMode(mode)
|
gin.SetMode(mode)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -33,14 +33,17 @@ func main() {
|
|||||||
|
|
||||||
// Set up the handler function using the provided response message
|
// Set up the handler function using the provided response message
|
||||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
// add a wait to simulate a slow query
|
// add a wait to simulate a slow query
|
||||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
time.Sleep(wait)
|
time.Sleep(wait)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.String(200, *responseMessage)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// for issue #62 to check model name strips profile slug
|
// for issue #62 to check model name strips profile slug
|
||||||
@@ -63,8 +66,11 @@ func main() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "application/json")
|
||||||
c.String(200, *responseMessage)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// issue #41
|
// issue #41
|
||||||
@@ -104,6 +110,10 @@ func main() {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
||||||
|
// expose some header values for testing
|
||||||
|
"h_content_type": c.GetHeader("Content-Type"),
|
||||||
|
"h_content_length": c.GetHeader("Content-Length"),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
+150
-6
@@ -2,13 +2,18 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/shlex"
|
"github.com/google/shlex"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DEFAULT_GROUP_ID = "(default)"
|
||||||
|
|
||||||
type ModelConfig struct {
|
type ModelConfig struct {
|
||||||
Cmd string `yaml:"cmd"`
|
Cmd string `yaml:"cmd"`
|
||||||
Proxy string `yaml:"proxy"`
|
Proxy string `yaml:"proxy"`
|
||||||
@@ -24,15 +29,44 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|||||||
return SanitizeCommand(m.Cmd)
|
return SanitizeCommand(m.Cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GroupConfig struct {
|
||||||
|
Swap bool `yaml:"swap"`
|
||||||
|
Exclusive bool `yaml:"exclusive"`
|
||||||
|
Persistent bool `yaml:"persistent"`
|
||||||
|
Members []string `yaml:"members"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default values for GroupConfig
|
||||||
|
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawGroupConfig GroupConfig
|
||||||
|
defaults := rawGroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Persistent: false,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = GroupConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
LogRequests bool `yaml:"logRequests"`
|
LogRequests bool `yaml:"logRequests"`
|
||||||
LogLevel string `yaml:"logLevel"`
|
LogLevel string `yaml:"logLevel"`
|
||||||
Models map[string]ModelConfig `yaml:"models"`
|
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||||
|
|
||||||
// map aliases to actual model IDs
|
// map aliases to actual model IDs
|
||||||
aliases map[string]string
|
aliases map[string]string
|
||||||
|
|
||||||
|
// automatic port assignments
|
||||||
|
StartPort int `yaml:"startPort"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
@@ -53,31 +87,141 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig(path string) (*Config, error) {
|
func LoadConfig(path string) (Config, error) {
|
||||||
data, err := os.ReadFile(path)
|
file, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Config{}, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
return LoadConfigFromReader(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var config Config
|
var config Config
|
||||||
err = yaml.Unmarshal(data, &config)
|
err = yaml.Unmarshal(data, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
if config.HealthCheckTimeout < 15 {
|
||||||
config.HealthCheckTimeout = 15
|
config.HealthCheckTimeout = 15
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set default port ranges
|
||||||
|
if config.StartPort == 0 {
|
||||||
|
// default to 5800
|
||||||
|
config.StartPort = 5800
|
||||||
|
} else if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
// Populate the aliases map
|
// Populate the aliases map
|
||||||
config.aliases = make(map[string]string)
|
config.aliases = make(map[string]string)
|
||||||
for modelName, modelConfig := range config.Models {
|
for modelName, modelConfig := range config.Models {
|
||||||
for _, alias := range modelConfig.Aliases {
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
config.aliases[alias] = modelName
|
config.aliases[alias] = modelName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &config, nil
|
// iterate over the models and replace any ${PORT} with the next available port
|
||||||
|
// Get and sort all model IDs first, makes testing more consistent
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||||
|
|
||||||
|
// iterate over the sorted models
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
if strings.Contains(modelConfig.Cmd, "${PORT}") {
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
|
||||||
|
if modelConfig.Proxy == "" {
|
||||||
|
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
|
||||||
|
} else {
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
|
||||||
|
}
|
||||||
|
nextPort++
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
} else if modelConfig.Proxy == "" {
|
||||||
|
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
// check that members are all unique in the groups
|
||||||
|
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
// Check for duplicates within this group
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
// Check if member is used in another group
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewrites the yaml to include a default group with any orphaned models
|
||||||
|
func AddDefaultGroupToConfig(config Config) Config {
|
||||||
|
|
||||||
|
if config.Groups == nil {
|
||||||
|
config.Groups = make(map[string]GroupConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultGroup := GroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
// if groups is empty, create a default group and put
|
||||||
|
// all models into it
|
||||||
|
if len(config.Groups) == 0 {
|
||||||
|
for modelName := range config.Models {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// iterate over existing group members and add non-grouped models into the default group
|
||||||
|
for modelName, _ := range config.Models {
|
||||||
|
foundModel := false
|
||||||
|
found:
|
||||||
|
// search for the model in existing groups
|
||||||
|
for _, groupConfig := range config.Groups {
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if member == modelName {
|
||||||
|
foundModel = true
|
||||||
|
break found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundModel {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||||
|
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||||
|
|
||||||
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
|
|||||||
+186
-1
@@ -3,6 +3,7 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -35,11 +36,32 @@ models:
|
|||||||
aliases:
|
aliases:
|
||||||
- "m2"
|
- "m2"
|
||||||
checkEndpoint: "/"
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "mthree"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model4:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8082"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
healthCheckTimeout: 15
|
healthCheckTimeout: 15
|
||||||
profiles:
|
profiles:
|
||||||
test:
|
test:
|
||||||
- model1
|
- model1
|
||||||
- model2
|
- model2
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
forever:
|
||||||
|
exclusive: false
|
||||||
|
persistent: true
|
||||||
|
members:
|
||||||
|
- "model4"
|
||||||
`
|
`
|
||||||
|
|
||||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
@@ -52,7 +74,8 @@ profiles:
|
|||||||
t.Fatalf("Failed to load config: %v", err)
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := &Config{
|
expected := Config{
|
||||||
|
StartPort: 5800,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
@@ -68,6 +91,18 @@ profiles:
|
|||||||
Env: nil,
|
Env: nil,
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
},
|
},
|
||||||
|
"model3": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"mthree"},
|
||||||
|
Env: nil,
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model4": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8082",
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Profiles: map[string][]string{
|
Profiles: map[string][]string{
|
||||||
@@ -77,6 +112,25 @@ profiles:
|
|||||||
"m1": "model1",
|
"m1": "model1",
|
||||||
"model-one": "model1",
|
"model-one": "model1",
|
||||||
"m2": "model2",
|
"m2": "model2",
|
||||||
|
"mthree": "model3",
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,6 +141,63 @@ profiles:
|
|||||||
assert.Equal(t, "model1", realname)
|
assert.Equal(t, "model1", realname)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
group2:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
`
|
||||||
|
// Load the config and verify
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
|
||||||
|
// a Contains as order of the map is not guaranteed
|
||||||
|
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
aliases:
|
||||||
|
- m1
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
aliases:
|
||||||
|
- m1
|
||||||
|
- m2
|
||||||
|
`
|
||||||
|
// Load the config and verify
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
|
||||||
|
// this is a contains because it could be `model1` or `model2` depending on the order
|
||||||
|
// go decided on the order of the map
|
||||||
|
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||||
|
}
|
||||||
|
|
||||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
config := &ModelConfig{
|
config := &ModelConfig{
|
||||||
Cmd: `python model1.py \
|
Cmd: `python model1.py \
|
||||||
@@ -174,3 +285,77 @@ func TestConfig_SanitizeCommand(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, args)
|
assert.Nil(t, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("Default Port Ranges", func(t *testing.T) {
|
||||||
|
content := ``
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
})
|
||||||
|
t.Run("User specific port ranges", func(t *testing.T) {
|
||||||
|
content := `startPort: 1000`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 1000, config.StartPort)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Invalid start port", func(t *testing.T) {
|
||||||
|
content := `startPort: abcd`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("start port must be greater than 1", func(t *testing.T) {
|
||||||
|
content := `startPort: -99`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Automatic port assignments", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 5800
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: svr --port ${PORT}
|
||||||
|
model2:
|
||||||
|
cmd: svr --port ${PORT}
|
||||||
|
proxy: "http://172.11.22.33:${PORT}"
|
||||||
|
model3:
|
||||||
|
cmd: svr --port 1999
|
||||||
|
proxy: "http://1.2.3.4:1999"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
||||||
|
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
||||||
|
|
||||||
|
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
||||||
|
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
||||||
|
|
||||||
|
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
||||||
|
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: svr --port 111
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
nextTestPort int = 12000
|
nextTestPort int = 12000
|
||||||
portMutex sync.Mutex
|
portMutex sync.Mutex
|
||||||
|
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check if the binary exists
|
// Check if the binary exists
|
||||||
@@ -26,6 +27,17 @@ func TestMain(m *testing.M) {
|
|||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
switch os.Getenv("LOG_LEVEL") {
|
||||||
|
case "debug":
|
||||||
|
testLogger.SetLogLevel(LevelDebug)
|
||||||
|
case "warn":
|
||||||
|
testLogger.SetLogLevel(LevelWarn)
|
||||||
|
case "info":
|
||||||
|
testLogger.SetLogLevel(LevelInfo)
|
||||||
|
default:
|
||||||
|
testLogger.SetLogLevel(LevelWarn)
|
||||||
|
}
|
||||||
|
|
||||||
m.Run()
|
m.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -170,6 +170,7 @@
|
|||||||
|
|
||||||
this.eventSource.onmessage = (event) => {
|
this.eventSource.onmessage = (event) => {
|
||||||
this.logData += event.data;
|
this.logData += event.data;
|
||||||
|
this.logData = this.logData.slice(-1024 * 100);
|
||||||
this.render();
|
this.render();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
+36
-26
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -93,17 +94,17 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
|||||||
defer p.stateMutex.Unlock()
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
if p.state != expectedState {
|
if p.state != expectedState {
|
||||||
p.proxyLogger.Warnf("swapState() Unexpected current state %s, expected %s", p.state, expectedState)
|
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
|
||||||
return p.state, ErrExpectedStateMismatch
|
return p.state, ErrExpectedStateMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isValidTransition(p.state, newState) {
|
if !isValidTransition(p.state, newState) {
|
||||||
p.proxyLogger.Warnf("swapState() Invalid state transition from %s to %s", p.state, newState)
|
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
|
||||||
return p.state, ErrInvalidStateTransition
|
return p.state, ErrInvalidStateTransition
|
||||||
}
|
}
|
||||||
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
p.proxyLogger.Debugf("swapState() State transitioned from %s to %s", expectedState, newState)
|
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||||
return p.state, nil
|
return p.state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,7 +188,7 @@ func (p *Process) start() error {
|
|||||||
// Capture the exit error for later signaling
|
// Capture the exit error for later signaling
|
||||||
go func() {
|
go func() {
|
||||||
exitErr := p.cmd.Wait()
|
exitErr := p.cmd.Wait()
|
||||||
p.proxyLogger.Debugf("cmd.Wait() returned for [%s] error: %v", p.ID, exitErr)
|
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||||
p.cmdWaitChan <- exitErr
|
p.cmdWaitChan <- exitErr
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -236,32 +237,32 @@ func (p *Process) start() error {
|
|||||||
return errors.New("health check interrupted due to shutdown")
|
return errors.New("health check interrupted due to shutdown")
|
||||||
case exitErr := <-p.cmdWaitChan:
|
case exitErr := <-p.cmdWaitChan:
|
||||||
if exitErr != nil {
|
if exitErr != nil {
|
||||||
p.proxyLogger.Warnf("upstream command exited prematurely with error: %v", exitErr)
|
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||||
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
|
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
|
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Warnf("upstream command exited prematurely with no error")
|
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||||
return fmt.Errorf("upstream command exited prematurely with no error AND state swap failed: %v, current state: %v", err, curState)
|
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("upstream command exited prematurely with no error")
|
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||||
p.proxyLogger.Infof("Health check passed on %s", healthURL)
|
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||||
cancelHealthCheck()
|
cancelHealthCheck()
|
||||||
break loop
|
break loop
|
||||||
} else {
|
} else {
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
endTime, _ := checkDeadline.Deadline()
|
endTime, _ := checkDeadline.Deadline()
|
||||||
ttl := time.Until(endTime)
|
ttl := time.Until(endTime)
|
||||||
p.proxyLogger.Infof("Connection refused on %s, giving up in %.0fs", healthURL, ttl.Seconds())
|
p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds())
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
|
p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -285,7 +286,7 @@ func (p *Process) start() error {
|
|||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
|
||||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||||
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
|
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||||
p.Stop()
|
p.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -301,13 +302,17 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// wait for any inflight requests before proceeding
|
// wait for any inflight requests before proceeding
|
||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
p.proxyLogger.Debugf("Stopping process [%s]", p.ID)
|
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
// calling Stop() when state is invalid is a no-op
|
||||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||||
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
|
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,7 +320,7 @@ func (p *Process) Stop() {
|
|||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(5 * time.Second)
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
|
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,24 +338,24 @@ func (p *Process) Shutdown() {
|
|||||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||||
stopStartTime := time.Now()
|
stopStartTime := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
p.proxyLogger.Debugf("Process [%s] stopCommand took %v", p.ID, time.Since(stopStartTime))
|
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
||||||
defer cancelTimeout()
|
defer cancelTimeout()
|
||||||
|
|
||||||
if p.cmd == nil || p.cmd.Process == nil {
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
|
p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.terminateProcess(); err != nil {
|
if err := p.terminateProcess(); err != nil {
|
||||||
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
|
p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
|
p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID)
|
||||||
p.cmd.Process.Kill()
|
p.cmd.Process.Kill()
|
||||||
case err := <-p.cmdWaitChan:
|
case err := <-p.cmdWaitChan:
|
||||||
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
|
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
|
||||||
@@ -359,24 +364,23 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
|||||||
// succeeded but that's not a case llama-swap is handling for now.
|
// succeeded but that's not a case llama-swap is handling for now.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errno, ok := err.(syscall.Errno); ok {
|
if errno, ok := err.(syscall.Errno); ok {
|
||||||
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
|
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||||
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
|
p.proxyLogger.Infof("<%s> Process stopped OK", p.ID)
|
||||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||||
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
|
p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID)
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
|
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Timeout: 500 * time.Millisecond,
|
Timeout: 500 * time.Millisecond,
|
||||||
}
|
}
|
||||||
@@ -436,6 +440,12 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header = r.Header.Clone()
|
req.Header = r.Header.Clone()
|
||||||
|
|
||||||
|
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
req.ContentLength = contentLength
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
@@ -471,6 +481,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
totalTime := time.Since(requestBeginTime)
|
totalTime := time.Since(requestBeginTime)
|
||||||
p.proxyLogger.Debugf("Process [%s] request %s - start: %v, total: %v",
|
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||||
p.ID, r.RequestURI, startDuration, totalTime)
|
p.ID, r.RequestURI, startDuration, totalTime)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -337,6 +337,6 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
|||||||
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||||
process.healthCheckLoopInterval = time.Second // make it faster
|
process.healthCheckLoopInterval = time.Second // make it faster
|
||||||
err := process.start()
|
err := process.start()
|
||||||
assert.Equal(t, "upstream command exited prematurely with no error", err.Error())
|
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||||
assert.Equal(t, process.CurrentState(), StateFailed)
|
assert.Equal(t, process.CurrentState(), StateFailed)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,113 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProcessGroup struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
config Config
|
||||||
|
id string
|
||||||
|
swap bool
|
||||||
|
exclusive bool
|
||||||
|
persistent bool
|
||||||
|
|
||||||
|
proxyLogger *LogMonitor
|
||||||
|
upstreamLogger *LogMonitor
|
||||||
|
|
||||||
|
// map of current processes
|
||||||
|
processes map[string]*Process
|
||||||
|
lastUsedProcess string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||||
|
groupConfig, ok := config.Groups[id]
|
||||||
|
if !ok {
|
||||||
|
panic("Unable to find configuration for group id: " + id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pg := &ProcessGroup{
|
||||||
|
id: id,
|
||||||
|
config: config,
|
||||||
|
swap: groupConfig.Swap,
|
||||||
|
exclusive: groupConfig.Exclusive,
|
||||||
|
persistent: groupConfig.Persistent,
|
||||||
|
proxyLogger: proxyLogger,
|
||||||
|
upstreamLogger: upstreamLogger,
|
||||||
|
processes: make(map[string]*Process),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a Process for each member in the group
|
||||||
|
for _, modelID := range groupConfig.Members {
|
||||||
|
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||||
|
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
||||||
|
pg.processes[modelID] = process
|
||||||
|
}
|
||||||
|
|
||||||
|
return pg
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyRequest proxies a request to the specified model
|
||||||
|
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||||
|
if !pg.HasMember(modelID) {
|
||||||
|
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pg.swap {
|
||||||
|
pg.Lock()
|
||||||
|
if pg.lastUsedProcess != modelID {
|
||||||
|
if pg.lastUsedProcess != "" {
|
||||||
|
pg.processes[pg.lastUsedProcess].Stop()
|
||||||
|
}
|
||||||
|
pg.lastUsedProcess = modelID
|
||||||
|
}
|
||||||
|
pg.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
pg.processes[modelID].ProxyRequest(writer, request)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||||
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) StopProcesses() {
|
||||||
|
pg.Lock()
|
||||||
|
defer pg.Unlock()
|
||||||
|
pg.stopProcesses()
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopProcesses stops all processes in the group
|
||||||
|
func (pg *ProcessGroup) stopProcesses() {
|
||||||
|
if len(pg.processes) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop Processes in parallel
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range pg.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(process *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
process.Stop()
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) Shutdown() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range pg.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(process *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
process.Shutdown()
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
|
"model4": getTestSimpleResponderConfig("model4"),
|
||||||
|
"model5": getTestSimpleResponderConfig("model5"),
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
"G2": {
|
||||||
|
Swap: false,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model3", "model4"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||||
|
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||||
|
assert.True(t, pg.HasMember("model5"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessGroup_HasMember(t *testing.T) {
|
||||||
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
|
assert.True(t, pg.HasMember("model1"))
|
||||||
|
assert.True(t, pg.HasMember("model2"))
|
||||||
|
assert.False(t, pg.HasMember("model3"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
||||||
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
|
defer pg.StopProcesses()
|
||||||
|
|
||||||
|
tests := []string{"model1", "model2"}
|
||||||
|
|
||||||
|
for _, modelName := range tests {
|
||||||
|
t.Run(modelName, func(t *testing.T) {
|
||||||
|
reqBody := `{"x", "y"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
|
||||||
|
// make sure only one process is in the running state
|
||||||
|
count := 0
|
||||||
|
for _, process := range pg.processes {
|
||||||
|
if process.CurrentState() == StateReady {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, count)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||||
|
defer pg.StopProcesses()
|
||||||
|
|
||||||
|
tests := []string{"model3", "model4"}
|
||||||
|
|
||||||
|
for _, modelName := range tests {
|
||||||
|
t.Run(modelName, func(t *testing.T) {
|
||||||
|
reqBody := `{"x", "y"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure all the processes are running
|
||||||
|
for _, process := range pg.processes {
|
||||||
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
}
|
||||||
|
}
|
||||||
+108
-156
@@ -26,17 +26,18 @@ const (
|
|||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config *Config
|
config Config
|
||||||
currentProcesses map[string]*Process
|
ginEngine *gin.Engine
|
||||||
ginEngine *gin.Engine
|
|
||||||
|
|
||||||
// logging
|
// logging
|
||||||
proxyLogger *LogMonitor
|
proxyLogger *LogMonitor
|
||||||
upstreamLogger *LogMonitor
|
upstreamLogger *LogMonitor
|
||||||
muxLogger *LogMonitor
|
muxLogger *LogMonitor
|
||||||
|
|
||||||
|
processGroups map[string]*ProcessGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config *Config) *ProxyManager {
|
func New(config Config) *ProxyManager {
|
||||||
// set up loggers
|
// set up loggers
|
||||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||||
@@ -65,13 +66,20 @@ func New(config *Config) *ProxyManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: config,
|
||||||
currentProcesses: make(map[string]*Process),
|
ginEngine: gin.New(),
|
||||||
ginEngine: gin.New(),
|
|
||||||
|
|
||||||
proxyLogger: proxyLogger,
|
proxyLogger: proxyLogger,
|
||||||
muxLogger: stdoutLogger,
|
muxLogger: stdoutLogger,
|
||||||
upstreamLogger: upstreamLogger,
|
upstreamLogger: upstreamLogger,
|
||||||
|
|
||||||
|
processGroups: make(map[string]*ProcessGroup),
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the process groups
|
||||||
|
for groupID := range config.Groups {
|
||||||
|
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
||||||
|
pm.processGroups[groupID] = processGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
@@ -200,27 +208,17 @@ func (pm *ProxyManager) StopProcesses() {
|
|||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
pm.stopProcesses()
|
|
||||||
}
|
|
||||||
|
|
||||||
// for internal usage
|
|
||||||
func (pm *ProxyManager) stopProcesses() {
|
|
||||||
if len(pm.currentProcesses) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// stop Processes in parallel
|
// stop Processes in parallel
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, process := range pm.currentProcesses {
|
for _, processGroup := range pm.processGroups {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(process *Process) {
|
go func(processGroup *ProcessGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
process.Stop()
|
processGroup.stopProcesses()
|
||||||
}(process)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
pm.currentProcesses = make(map[string]*Process)
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown is called to shutdown all upstream processes
|
// Shutdown is called to shutdown all upstream processes
|
||||||
@@ -229,18 +227,44 @@ func (pm *ProxyManager) Shutdown() {
|
|||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// shutdown process in parallel
|
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, process := range pm.currentProcesses {
|
// Send shutdown signal to all process in groups
|
||||||
|
for _, processGroup := range pm.processGroups {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(process *Process) {
|
go func(processGroup *ProcessGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
process.Shutdown()
|
processGroup.Shutdown()
|
||||||
}(process)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||||
|
// de-alias the real model name and get a real one
|
||||||
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
|
if processGroup == nil {
|
||||||
|
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if processGroup.exclusive {
|
||||||
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
||||||
|
for groupId, otherGroup := range pm.processGroups {
|
||||||
|
if groupId != processGroup.id && !otherGroup.persistent {
|
||||||
|
otherGroup.StopProcesses()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return processGroup, realModelName, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := []interface{}{}
|
data := []interface{}{}
|
||||||
for id, modelConfig := range pm.config.Models {
|
for id, modelConfig := range pm.config.Models {
|
||||||
@@ -270,79 +294,6 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|
||||||
pm.Lock()
|
|
||||||
defer pm.Unlock()
|
|
||||||
|
|
||||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
|
||||||
|
|
||||||
if profileName != "" {
|
|
||||||
if _, found := pm.config.Profiles[profileName]; !found {
|
|
||||||
return nil, fmt.Errorf("model group not found %s", profileName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// de-alias the real model name and get a real one
|
|
||||||
realModelName, found := pm.config.RealModelName(modelName)
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if model is part of the profile
|
|
||||||
if profileName != "" {
|
|
||||||
found := false
|
|
||||||
for _, item := range pm.config.Profiles[profileName] {
|
|
||||||
if item == realModelName {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// exit early when already running, otherwise stop everything and swap
|
|
||||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
|
||||||
|
|
||||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
|
||||||
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
|
|
||||||
return process, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// stop all running models
|
|
||||||
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
|
|
||||||
pm.stopProcesses()
|
|
||||||
if profileName == "" {
|
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
|
||||||
}
|
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
|
||||||
processKey := ProcessKeyName(profileName, modelID)
|
|
||||||
pm.currentProcesses[processKey] = process
|
|
||||||
} else {
|
|
||||||
for _, modelName := range pm.config.Profiles[profileName] {
|
|
||||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
|
||||||
}
|
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
|
||||||
processKey := ProcessKeyName(profileName, modelID)
|
|
||||||
pm.currentProcesses[processKey] = process
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// requestedProcessKey should exist due to swap
|
|
||||||
return pm.currentProcesses[requestedProcessKey], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||||
requestedModel := c.Param("model_id")
|
requestedModel := c.Param("model_id")
|
||||||
|
|
||||||
@@ -351,13 +302,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
if err != nil {
|
||||||
} else {
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
// rewrite the path
|
return
|
||||||
c.Request.URL.Path = c.Param("upstreamPath")
|
|
||||||
process.ProxyRequest(c.Writer, c.Request)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rewrite the path
|
||||||
|
c.Request.URL.Path = c.Param("upstreamPath")
|
||||||
|
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
||||||
@@ -395,31 +348,23 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
if requestedModel == "" {
|
if requestedModel == "" {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
process, err := pm.swapModel(requestedModel)
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// issue #69 allow custom model names to be sent to upstream
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
if process.config.UseModelName != "" {
|
useModelName := pm.config.Models[realModelName].UseModelName
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
if useModelName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
|
||||||
if profileName != "" {
|
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
@@ -428,16 +373,14 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
c.Request.Header.Del("transfer-encoding")
|
c.Request.Header.Del("transfer-encoding")
|
||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
|
||||||
process.ProxyRequest(c.Writer, c.Request)
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
||||||
// Create a new buffer for the reconstructed request
|
|
||||||
var requestBuffer bytes.Buffer
|
|
||||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
||||||
|
|
||||||
// Parse multipart form
|
// Parse multipart form
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
@@ -451,15 +394,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Swap to the requested model
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||||
process, err := pm.swapModel(requestedModel)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get profile name and model name from the requested model
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
// Create a new buffer for the reconstructed request
|
||||||
|
var requestBuffer bytes.Buffer
|
||||||
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||||
|
|
||||||
// Copy all form values
|
// Copy all form values
|
||||||
for key, values := range c.Request.MultipartForm.Value {
|
for key, values := range c.Request.MultipartForm.Value {
|
||||||
@@ -467,10 +411,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
fieldValue := value
|
fieldValue := value
|
||||||
// If this is the model field and we have a profile, use just the model name
|
// If this is the model field and we have a profile, use just the model name
|
||||||
if key == "model" {
|
if key == "model" {
|
||||||
if process.config.UseModelName != "" {
|
// # issue #69 allow custom model names to be sent to upstream
|
||||||
fieldValue = process.config.UseModelName
|
useModelName := pm.config.Models[realModelName].UseModelName
|
||||||
} else if profileName != "" {
|
|
||||||
fieldValue = modelName
|
if useModelName != "" {
|
||||||
|
fieldValue = useModelName
|
||||||
|
} else {
|
||||||
|
fieldValue = requestedModel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
field, err := multipartWriter.CreateFormField(key)
|
field, err := multipartWriter.CreateFormField(key)
|
||||||
@@ -531,8 +478,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
modifiedReq.Header = c.Request.Header.Clone()
|
modifiedReq.Header = c.Request.Header.Clone()
|
||||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||||
|
|
||||||
|
// set the content length of the body
|
||||||
|
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||||
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||||
|
|
||||||
// Use the modified request for proxying
|
// Use the modified request for proxying
|
||||||
process.ProxyRequest(c.Writer, modifiedReq)
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||||
@@ -554,14 +509,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|||||||
context.Header("Content-Type", "application/json")
|
context.Header("Content-Type", "application/json")
|
||||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||||
|
|
||||||
for _, process := range pm.currentProcesses {
|
for _, processGroup := range pm.processGroups {
|
||||||
|
for _, process := range processGroup.processes {
|
||||||
// Append the process ID and State (multiple entries if profiles are being used).
|
if process.CurrentState() == StateReady {
|
||||||
runningProcesses = append(runningProcesses, gin.H{
|
runningProcesses = append(runningProcesses, gin.H{
|
||||||
"model": process.ID,
|
"model": process.ID,
|
||||||
"state": process.state,
|
"state": process.state,
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put the results under the `running` key.
|
// Put the results under the `running` key.
|
||||||
@@ -572,15 +528,11 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|||||||
context.JSON(http.StatusOK, response) // Always return 200 OK
|
context.JSON(http.StatusOK, response) // Always return 200 OK
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProcessKeyName(groupName, modelName string) string {
|
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
||||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
for _, group := range pm.processGroups {
|
||||||
}
|
if group.HasMember(modelName) {
|
||||||
|
return group
|
||||||
func splitRequestedModel(requestedModel string) (string, string) {
|
}
|
||||||
profileName, modelName := "", requestedModel
|
|
||||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
|
||||||
profileName = requestedModel[:idx]
|
|
||||||
modelName = requestedModel[idx+1:]
|
|
||||||
}
|
}
|
||||||
return profileName, modelName
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
+213
-313
@@ -8,6 +8,7 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,14 +17,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
config := &Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses()
|
||||||
@@ -36,59 +37,91 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
proxy.HandlerFunc(w, req)
|
proxy.HandlerFunc(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
|
||||||
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
|
||||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// make sure there's only one loaded model
|
|
||||||
assert.Len(t, proxy.currentProcesses, 1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
|
config := AddDefaultGroupToConfig(Config{
|
||||||
model1 := "path1/model1"
|
|
||||||
model2 := "path2/model2"
|
|
||||||
|
|
||||||
profileModel1 := ProcessKeyName("test", model1)
|
|
||||||
profileModel2 := ProcessKeyName("test", model2)
|
|
||||||
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
model1: getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
model2: getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {model1, model2},
|
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
Groups: map[string]GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model1"},
|
||||||
|
},
|
||||||
|
"G2": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
for modelID, requestedModel := range map[string]string{
|
tests := []string{"model1", "model2"}
|
||||||
"model1": profileModel1,
|
for _, requestedModel := range tests {
|
||||||
"model2": profileModel2,
|
t.Run(requestedModel, func(t *testing.T) {
|
||||||
} {
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), requestedModel)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure there's two loaded models
|
||||||
|
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
||||||
|
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that a persistent group is not affected by the swapping behaviour of
|
||||||
|
// other groups.
|
||||||
|
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||||
|
config := AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
// the forever group is persistent and should not be affected by model1
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
// make requests to load all models, loading model1 should not affect model2
|
||||||
|
tests := []string{"model2", "model1"}
|
||||||
|
for _, requestedModel := range tests {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.HandlerFunc(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelID)
|
assert.Contains(t, w.Body.String(), requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// make sure there's two loaded models
|
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
||||||
assert.Len(t, proxy.currentProcesses, 2)
|
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
||||||
_, exists := proxy.currentProcesses[profileModel1]
|
|
||||||
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
|
||||||
|
|
||||||
_, exists = proxy.currentProcesses[profileModel2]
|
|
||||||
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// When a request for a different model comes in ProxyManager should wait until
|
// When a request for a different model comes in ProxyManager should wait until
|
||||||
@@ -98,7 +131,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
t.Skip("skipping slow test")
|
t.Skip("skipping slow test")
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
@@ -106,7 +139,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses()
|
||||||
@@ -133,7 +166,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
|
|
||||||
results[key] = w.Body.String()
|
var response map[string]string
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
results[key] = response["responseMessage"]
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
}(key)
|
}(key)
|
||||||
|
|
||||||
@@ -149,7 +184,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||||
config := &Config{
|
config := Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
@@ -217,51 +252,6 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ProfileNonMember(t *testing.T) {
|
|
||||||
|
|
||||||
model1 := "path1/model1"
|
|
||||||
model2 := "path2/model2"
|
|
||||||
|
|
||||||
profileMemberName := ProcessKeyName("test", model1)
|
|
||||||
profileNonMemberName := ProcessKeyName("test", model2)
|
|
||||||
|
|
||||||
config := &Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
model1: getTestSimpleResponderConfig("model1"),
|
|
||||||
model2: getTestSimpleResponderConfig("model2"),
|
|
||||||
},
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {model1},
|
|
||||||
},
|
|
||||||
LogLevel: "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
// actual member of profile
|
|
||||||
{
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
assert.Contains(t, w.Body.String(), "model1")
|
|
||||||
}
|
|
||||||
|
|
||||||
// actual model, but non-member will 404
|
|
||||||
{
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_Shutdown(t *testing.T) {
|
func TestProxyManager_Shutdown(t *testing.T) {
|
||||||
// make broken model configurations
|
// make broken model configurations
|
||||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||||
@@ -273,24 +263,27 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||||
model3Config.Proxy = "http://localhost:10003/"
|
model3Config.Proxy = "http://localhost:10003/"
|
||||||
|
|
||||||
config := &Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"model1", "model2", "model3"},
|
|
||||||
},
|
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": model1Config,
|
"model1": model1Config,
|
||||||
"model2": model2Config,
|
"model2": model2Config,
|
||||||
"model3": model3Config,
|
"model3": model3Config,
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
Groups: map[string]GroupConfig{
|
||||||
|
"test": {
|
||||||
|
Swap: false,
|
||||||
|
Members: []string{"model1", "model2", "model3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
|
|
||||||
// Start all the processes
|
// Start all the processes
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
|
for _, modelName := range []string{"model1", "model2", "model3"} {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(modelName string) {
|
go func(modelName string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@@ -298,11 +291,10 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// send a request to trigger the proxy to load
|
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.HandlerFunc(w, req)
|
||||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||||
//fmt.Println(w.Code, w.Body.String())
|
|
||||||
}(modelName)
|
}(modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,67 +306,44 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_Unload(t *testing.T) {
|
func TestProxyManager_Unload(t *testing.T) {
|
||||||
config := &Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
proc, err := proxy.swapModel("model1")
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||||
assert.NoError(t, err)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
assert.NotNil(t, proc)
|
|
||||||
|
|
||||||
assert.Len(t, proxy.currentProcesses, 1)
|
|
||||||
req := httptest.NewRequest("GET", "/unload", nil)
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.HandlerFunc(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||||
|
req = httptest.NewRequest("GET", "/unload", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Equal(t, w.Body.String(), "OK")
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
assert.Len(t, proxy.currentProcesses, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// issue 62, strip profile slug from model name
|
// give it a bit of time to stop
|
||||||
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
<-time.After(time.Millisecond * 250)
|
||||||
config := &Config{
|
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
|
||||||
},
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
|
||||||
},
|
|
||||||
LogLevel: "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses()
|
|
||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
assert.Contains(t, w.Body.String(), "ok")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||||
|
|
||||||
// Shared configuration
|
// Shared configuration
|
||||||
config := &Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
Profiles: map[string][]string{
|
LogLevel: "debug",
|
||||||
"test": {"model1", "model2"},
|
})
|
||||||
},
|
|
||||||
LogLevel: "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define a helper struct to parse the JSON response.
|
// Define a helper struct to parse the JSON response.
|
||||||
type RunningResponse struct {
|
type RunningResponse struct {
|
||||||
@@ -429,238 +398,127 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
// Is the model loaded?
|
// Is the model loaded?
|
||||||
assert.Equal(t, "ready", response.Running[0].State)
|
assert.Equal(t, "ready", response.Running[0].State)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("multiple models via profile", func(t *testing.T) {
|
|
||||||
// Load more than one model.
|
|
||||||
for _, model := range []string{"model1", "model2"} {
|
|
||||||
profileModel := ProcessKeyName("test", model)
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate the browser call.
|
|
||||||
req := httptest.NewRequest("GET", "/running", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.HandlerFunc(w, req)
|
|
||||||
|
|
||||||
var response RunningResponse
|
|
||||||
|
|
||||||
// The JSON response must be valid.
|
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
|
||||||
|
|
||||||
// The response should contain 2 models.
|
|
||||||
assert.Len(t, response.Running, 2)
|
|
||||||
|
|
||||||
expectedModels := map[string]struct{}{
|
|
||||||
"model1": {},
|
|
||||||
"model2": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate through the models and check their states as well.
|
|
||||||
for _, entry := range response.Running {
|
|
||||||
_, exists := expectedModels[entry.Model]
|
|
||||||
assert.True(t, exists, "unexpected model %s", entry.Model)
|
|
||||||
assert.Equal(t, "ready", entry.State)
|
|
||||||
delete(expectedModels, entry.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since we deleted each model while testing for its validity we should have no more models in the response.
|
|
||||||
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||||
config := &Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"TheExpectedModel"},
|
|
||||||
},
|
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
testCases := []struct {
|
// Create a buffer with multipart form data
|
||||||
name string
|
var b bytes.Buffer
|
||||||
modelInput string
|
w := multipart.NewWriter(&b)
|
||||||
expectModel string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "With Profile Prefix",
|
|
||||||
modelInput: "test:TheExpectedModel",
|
|
||||||
expectModel: "TheExpectedModel", // Profile prefix should be stripped
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Without Profile Prefix",
|
|
||||||
modelInput: "TheExpectedModel",
|
|
||||||
expectModel: "TheExpectedModel", // Should remain the same
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
// Add the model field
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
fw, err := w.CreateFormField("model")
|
||||||
// Create a buffer with multipart form data
|
assert.NoError(t, err)
|
||||||
var b bytes.Buffer
|
_, err = fw.Write([]byte("TheExpectedModel"))
|
||||||
w := multipart.NewWriter(&b)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Add the model field
|
// Add a file field
|
||||||
fw, err := w.CreateFormField("model")
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
_, err = fw.Write([]byte(tc.modelInput))
|
// Generate random content length between 10 and 20
|
||||||
assert.NoError(t, err)
|
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||||
|
content := make([]byte, contentLength)
|
||||||
|
_, err = fw.Write(content)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
w.Close()
|
||||||
|
|
||||||
// Add a file field
|
// Create the request with the multipart form data
|
||||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
assert.NoError(t, err)
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
// Generate random content length between 10 and 20
|
rec := httptest.NewRecorder()
|
||||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
proxy.HandlerFunc(rec, req)
|
||||||
content := make([]byte, contentLength)
|
|
||||||
_, err = fw.Write(content)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
w.Close()
|
|
||||||
|
|
||||||
// Create the request with the multipart form data
|
// Verify the response
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
var response map[string]string
|
||||||
rec := httptest.NewRecorder()
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||||
proxy.HandlerFunc(rec, req)
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "TheExpectedModel", response["model"])
|
||||||
// Verify the response
|
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
|
||||||
var response map[string]string
|
|
||||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tc.expectModel, response["model"])
|
|
||||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_SplitRequestedModel(t *testing.T) {
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
requestedModel string
|
|
||||||
expectedProfile string
|
|
||||||
expectedModel string
|
|
||||||
}{
|
|
||||||
{"no profile", "gpt-4", "", "gpt-4"},
|
|
||||||
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
|
|
||||||
{"only profile", "profile1:", "profile1", ""},
|
|
||||||
{"empty model", ":gpt-4", "", "gpt-4"},
|
|
||||||
{"empty profile", ":", "", ""},
|
|
||||||
{"no split char", "gpt-4", "", "gpt-4"},
|
|
||||||
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
profileName, modelName := splitRequestedModel(tt.requestedModel)
|
|
||||||
if profileName != tt.expectedProfile {
|
|
||||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
|
||||||
}
|
|
||||||
if modelName != tt.expectedModel {
|
|
||||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||||
func TestProxyManager_UseModelName(t *testing.T) {
|
func TestProxyManager_UseModelName(t *testing.T) {
|
||||||
|
|
||||||
upstreamModelName := "upstreamModel"
|
upstreamModelName := "upstreamModel"
|
||||||
|
|
||||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||||
modelConfig.UseModelName = upstreamModelName
|
modelConfig.UseModelName = upstreamModelName
|
||||||
|
|
||||||
config := &Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"model1"},
|
|
||||||
},
|
|
||||||
|
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": modelConfig,
|
"model1": modelConfig,
|
||||||
},
|
},
|
||||||
|
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
tests := []struct {
|
requestedModel := "model1"
|
||||||
description string
|
|
||||||
requestedModel string
|
|
||||||
}{
|
|
||||||
{"useModelName over rides requested model", "model1"},
|
|
||||||
{"useModelName over rides requested profile:model", "test:model1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
||||||
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
w := httptest.NewRecorder()
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.HandlerFunc(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||||
|
})
|
||||||
|
|
||||||
})
|
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
|
||||||
}
|
// Create a buffer with multipart form data
|
||||||
|
var b bytes.Buffer
|
||||||
|
w := multipart.NewWriter(&b)
|
||||||
|
|
||||||
for _, tt := range tests {
|
// Add the model field
|
||||||
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
|
fw, err := w.CreateFormField("model")
|
||||||
// Create a buffer with multipart form data
|
assert.NoError(t, err)
|
||||||
var b bytes.Buffer
|
_, err = fw.Write([]byte(requestedModel))
|
||||||
w := multipart.NewWriter(&b)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Add the model field
|
// Add a file field
|
||||||
fw, err := w.CreateFormField("model")
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
_, err = fw.Write([]byte(tt.requestedModel))
|
_, err = fw.Write([]byte("test"))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
w.Close()
|
||||||
|
|
||||||
// Add a file field
|
// Create the request with the multipart form data
|
||||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
assert.NoError(t, err)
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
_, err = fw.Write([]byte("test"))
|
rec := httptest.NewRecorder()
|
||||||
assert.NoError(t, err)
|
proxy.HandlerFunc(rec, req)
|
||||||
w.Close()
|
|
||||||
|
|
||||||
// Create the request with the multipart form data
|
// Verify the response
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
var response map[string]string
|
||||||
rec := httptest.NewRecorder()
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||||
proxy.HandlerFunc(rec, req)
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, upstreamModelName, response["model"])
|
||||||
// Verify the response
|
})
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
|
||||||
var response map[string]string
|
|
||||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, upstreamModelName, response["model"])
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||||
config := &Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -720,3 +578,45 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_Upstream(t *testing.T) {
|
||||||
|
config := AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||||
|
config := AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
var response map[string]string
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
assert.Equal(t, "81", response["h_content_length"])
|
||||||
|
assert.Equal(t, "model1", response["responseMessage"])
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user