Compare commits

...

17 Commits

Author SHA1 Message Date
Benson Wong ab5a048584 bug fix with missing early return statements fix #112 2025-05-04 07:34:24 -07:00
Benson Wong d9a1ddea0d Truncate web logs to 100K characters (#111)
* set log limit to 100K in browser
2025-05-02 23:43:21 -07:00
Benson Wong e7ab024ca0 small locking optimization 2025-05-02 23:18:07 -07:00
Benson Wong 448ccae959 Introduce Groups Feature (#107)
Groups allows more control over swapping behaviour when a model is requested. The new groups feature provides three ways to control swapping: within the group, swapping out other groups or keep the models in the group loaded persistently (never swapped out). 

Closes #96, #99 and #106.
2025-05-02 22:35:38 -07:00
Benson Wong ec0348e431 Reduce stale time for issues 2025-04-29 21:16:34 -07:00
Benson Wong 06eda7f591 tag all process logs with its ID (#103)
Makes identifying Process of log messages easier
2025-04-25 12:58:25 -07:00
Benson Wong 5fad24c16f Make checkHealthTimeout Interruptable during startup (#102)
interrupt and exit Process.start() early if the upstream process exits prematurely or unexpectedly.
2025-04-24 14:39:33 -07:00
Benson Wong 8404244fab Moderate security update for golang/x/net -> v0.38.0 2025-04-24 09:58:40 -07:00
Benson Wong 712cd01081 fix confusing INFO message [no ci] 2025-04-24 09:56:20 -07:00
Benson Wong 1f7aa359b1 Update header image
AI has finally made my dreams of llamas in funny clothing and stuck in
a claw machine waiting to be picked come true!
2025-04-23 13:02:12 -07:00
Benson Wong b138d6cf25 fix starhistory in README 2025-04-15 20:23:46 -07:00
Benson Wong fb7c808082 add timing for Process start, stop, total request time (#91) 2025-04-14 14:34:59 -07:00
Benson Wong a7e640b0f7 add aider example 2025-04-07 12:37:14 -07:00
Benson Wong 593604dfdc Show proxy and upstream logs in separate columns in logs UI 2025-04-05 10:36:54 -07:00
Benson Wong b8f888f864 Logging Improvements (#88)
This change revamps the internal logging architecture to be more flexible and descriptive. Previously all logs from both llama-swap and upstream services were mixed together. This makes it harder to troubleshoot and identify problems. This PR adds these new endpoints: 

- `/logs/stream/proxy` - just llama-swap's logs
- `/logs/stream/upstream` - stdout output from the upstream server
2025-04-04 21:01:33 -07:00
Benson Wong 192b2ae621 Remove no longer needed test 2025-04-04 14:46:01 -07:00
Benson Wong b7f8cb5094 Limit Access-Control-Allow-Origin to OPTIONS preflight requests #85 2025-04-04 14:44:35 -07:00
23 changed files with 1544 additions and 673 deletions
+4 -4
View File
@@ -13,11 +13,11 @@ jobs:
steps: steps:
- uses: actions/stale@v9 - uses: actions/stale@v9
with: with:
days-before-issue-stale: 30 days-before-issue-stale: 14
days-before-issue-close: 14 days-before-issue-close: 14
stale-issue-label: "stale" stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
days-before-pr-stale: -1 days-before-pr-stale: -1
days-before-pr-close: -1 days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
+63 -21
View File
@@ -1,10 +1,8 @@
![llama-swap header image](header.jpeg) ![llama-swap header image](header2.png)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total) ![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml) ![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap) ![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap # llama-swap
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server. llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
@@ -28,7 +26,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31)) - `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58)) - `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741)) - ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- ✅ Automatic unloading of models after timeout by setting a `ttl` - ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support - ✅ Docker and Podman support
@@ -38,7 +36,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request. When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used. In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## config.yaml ## config.yaml
@@ -69,8 +67,8 @@ models:
# Default (and minimum) is 15 seconds # Default (and minimum) is 15 seconds
healthCheckTimeout: 60 healthCheckTimeout: 60
# Write HTTP logs (useful for troubleshooting), defaults to false # Valid log levels: debug, info (default), warn, error
logRequests: true logLevel: info
# define valid model values and the upstream server start # define valid model values and the upstream server start
models: models:
@@ -122,16 +120,58 @@ models:
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles eliminates swapping by running multiple models at the same time # Groups provide advanced controls over model swapping behaviour. Using groups
# some models can be kept loaded indefinitely, while others are swapped out.
# #
# Tips: # Tips:
# - each model must be listening on a unique address and port #
# - the model name is in this format: "profile_name:model", like "coding:qwen" # - models must be defined above in the Models section
# - the profile will load and unload all models in the profile at the same time # - a model can only be a member of one group
profiles: # - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
coding: # - see issue #109 for details
- "llama" #
- "qwen-unlisted" # NOTE: the example below uses model names that are not defined above for demonstration purposes
groups:
# group1 is the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap controls the model swapping behaviour in within the group
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# exclusive controls how the group affects other groups
# - true: causes all other groups to unload their models when this group runs a model
# - false: does not affect other groups
exclusive: true
# members references the models defined above
members:
- "llama"
- "qwen-unlisted"
# models in this group are never unloaded
"group2":
swap: false
exclusive: false
members:
- "docker-llama"
# (not defined above, here for example)
- "modelA"
- "modelB"
"forever":
# setting persistent to true causes the group to never be affected by the swapping behaviour of
# other groups. It is a shortcut to keeping some models always loaded.
persistent: true
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
``` ```
### Use Case Examples ### Use Case Examples
@@ -221,9 +261,15 @@ Of course, CLI access is also supported:
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' curl http://host/logs'
# streams logs # streams combined logs
curl -Ns 'http://host/logs/stream' curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes # stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time' curl -Ns http://host/logs/stream | grep 'eval time'
@@ -265,8 +311,4 @@ WantedBy=multi-user.target
## Star History ## Star History
<picture> [![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
+3 -3
View File
@@ -1,9 +1,9 @@
# Seconds to wait for llama.cpp to be available to serve requests # Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds # Default (and minimum): 15 seconds
healthCheckTimeout: 15 healthCheckTimeout: 90
# Log HTTP requests helpful for troubleshoot, defaults to False # valid log levels: debug, info (default), warn, error
logRequests: true logLevel: debug
models: models:
"llama": "llama":
+153
View File
@@ -0,0 +1,153 @@
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
## Here's what you you need:
- aider - [installation docs](https://aider.chat/docs/install.html)
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
- 24GB VRAM video card
## Running aider
The goal is getting this command line to work:
```sh
aider --architect \
--no-show-model-warnings \
--model openai/QwQ \
--editor-model openai/qwen-coder-32B \
--model-settings-file aider.model.settings.yml \
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1" \
```
Set `--openai-api-base` to the IP and port where your llama-swap is running.
## Create an aider model settings file
```yaml
# aider.model.settings.yml
#
# !!! important: model names must match llama-swap configuration names !!!
#
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
```
## llama-swap configuration
```yaml
# config.yaml
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
models:
"qwen-coder-32B":
proxy: "http://127.0.0.1:8999"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--cache-type-k q8_0 --cache-type-v q8_0
-ngl 99
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
"QwQ":
proxy: "http://127.0.0.1:9503"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
--min-p 0.01 --top-k 40 --top-p 0.95
-ngl 99
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
```
## Advanced, Dual GPU Configuration
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
In llama-swap's configuration file:
1. add a `profiles` section with `aider` as the profile name
2. using the `env` field to specify the GPU IDs for each model
```yaml
# config.yaml
# Add a profile for aider
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=0"
proxy: "http://127.0.0.1:8999"
cmd: /path/to/llama-server ...
"QwQ":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
cmd: /path/to/llama-server ...
```
Append the profile tag, `aider:`, to the model names in the model settings file
```yaml
# aider.model.settings.yml
- name: "openai/aider:QwQ"
weak_model_name: "openai/aider:qwen-coder-32B-aider"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
- name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
```
Run aider with:
```sh
$ aider --architect \
--no-show-model-warnings \
--model openai/aider:QwQ \
--editor-model openai/aider:qwen-coder-32B \
--config aider.conf.yml \
--model-settings-file aider.model.settings.yml
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1"
```
@@ -0,0 +1,28 @@
# this makes use of llama-swap's profile feature to
# keep the architect and editor models in VRAM on different GPUs
- name: "openai/aider:QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B"
- name: "openai/aider:qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/aider:qwen-coder-32B"
@@ -0,0 +1,26 @@
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
+49
View File
@@ -0,0 +1,49 @@
healthCheckTimeout: 300
logLevel: debug
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
env:
- "CUDA_VISIBLE_DEVICES=0"
aliases:
- coder
proxy: "http://127.0.0.1:8999"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--ctx-size-draft 16000
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
-ngl 99 -ngld 99
--draft-max 16 --draft-min 4 --draft-p-min 0.4
--cache-type-k q8_0 --cache-type-v q8_0
"QwQ":
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6
--repeat-penalty 1.1
--dry-multiplier 0.5
--min-p 0.01
--top-k 40
--top-p 0.95
-ngl 99 -ngld 99
+1 -1
View File
@@ -37,7 +37,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.37.0 // indirect golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
+2
View File
@@ -86,6 +86,8 @@ golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 351 KiB

+4
View File
@@ -34,6 +34,10 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if len(config.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}
if mode := os.Getenv("GIN_MODE"); mode != "" { if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode) gin.SetMode(mode)
} else { } else {
+100 -5
View File
@@ -3,12 +3,15 @@ package proxy
import ( import (
"fmt" "fmt"
"os" "os"
"sort"
"strings" "strings"
"github.com/google/shlex" "github.com/google/shlex"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct { type ModelConfig struct {
Cmd string `yaml:"cmd"` Cmd string `yaml:"cmd"`
Proxy string `yaml:"proxy"` Proxy string `yaml:"proxy"`
@@ -24,11 +27,38 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd) return SanitizeCommand(m.Cmd)
} }
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type Config struct { type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"` LogRequests bool `yaml:"logRequests"`
Models map[string]ModelConfig `yaml:"models"` LogLevel string `yaml:"logLevel"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// map aliases to actual model IDs // map aliases to actual model IDs
aliases map[string]string aliases map[string]string
@@ -52,16 +82,16 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
} }
} }
func LoadConfig(path string) (*Config, error) { func LoadConfig(path string) (Config, error) {
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return Config{}, err
} }
var config Config var config Config
err = yaml.Unmarshal(data, &config) err = yaml.Unmarshal(data, &config)
if err != nil { if err != nil {
return nil, err return Config{}, err
} }
if config.HealthCheckTimeout < 15 { if config.HealthCheckTimeout < 15 {
@@ -76,7 +106,72 @@ func LoadConfig(path string) (*Config, error) {
} }
} }
return &config, nil config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
} }
func SanitizeCommand(cmdStr string) ([]string, error) { func SanitizeCommand(cmdStr string) ([]string, error) {
+96 -1
View File
@@ -35,11 +35,31 @@ models:
aliases: aliases:
- "m2" - "m2"
checkEndpoint: "/" checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "mthree"
checkEndpoint: "/"
model4:
cmd: path/to/cmd --arg1 one
checkEndpoint: "/"
healthCheckTimeout: 15 healthCheckTimeout: 15
profiles: profiles:
test: test:
- model1 - model1
- model2 - model2
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
forever:
exclusive: false
persistent: true
members:
- "model4"
` `
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
@@ -52,7 +72,7 @@ profiles:
t.Fatalf("Failed to load config: %v", err) t.Fatalf("Failed to load config: %v", err)
} }
expected := &Config{ expected := Config{
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -68,6 +88,17 @@ profiles:
Env: nil, Env: nil,
CheckEndpoint: "/", CheckEndpoint: "/",
}, },
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: nil,
CheckEndpoint: "/",
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
CheckEndpoint: "/",
},
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{ Profiles: map[string][]string{
@@ -77,6 +108,25 @@ profiles:
"m1": "model1", "m1": "model1",
"model-one": "model1", "model-one": "model1",
"m2": "model2", "m2": "model2",
"mthree": "model3",
},
Groups: map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
}, },
} }
@@ -87,6 +137,51 @@ profiles:
assert.Equal(t, "model1", realname) assert.Equal(t, "model1", realname)
} }
func TestConfig_GroupMemberIsUnique(t *testing.T) {
// Create a temporary YAML file for testing
tempDir, err := os.MkdirTemp("", "test-config")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDir)
tempFile := filepath.Join(tempDir, "config.yaml")
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
healthCheckTimeout: 15
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
group2:
swap: true
exclusive: false
members: ["model2"]
`
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write temporary file: %v", err)
}
// Load the config and verify
_, err = LoadConfig(tempFile)
assert.NotNil(t, err)
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{ config := &ModelConfig{
Cmd: `python model1.py \ Cmd: `python model1.py \
+12
View File
@@ -14,6 +14,7 @@ import (
var ( var (
nextTestPort int = 12000 nextTestPort int = 12000
portMutex sync.Mutex portMutex sync.Mutex
testLogger = NewLogMonitorWriter(os.Stdout)
) )
// Check if the binary exists // Check if the binary exists
@@ -26,6 +27,17 @@ func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
switch os.Getenv("LOG_LEVEL") {
case "debug":
testLogger.SetLogLevel(LevelDebug)
case "warn":
testLogger.SetLogLevel(LevelWarn)
case "info":
testLogger.SetLogLevel(LevelInfo)
default:
testLogger.SetLogLevel(LevelWarn)
}
m.Run() m.Run()
} }
+189 -75
View File
@@ -12,32 +12,65 @@
flex-direction: column; flex-direction: column;
font-family: "Courier New", Courier, monospace; font-family: "Courier New", Courier, monospace;
} }
#log-controls { .log-container {
margin: 0.5em;
display: flex; display: flex;
align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1;
}
#log-controls input:focus {
outline: none; /* Ensures no outline is shown when the input is focused */
}
#log-stream {
flex: 1; flex: 1;
gap: 0.5em;
margin: 0.5em; margin: 0.5em;
min-height: 0;
}
.log-column {
display: flex;
flex-direction: column;
flex: 1;
min-width: 0;
transition: flex 0.3s ease;
}
.log-column.minimized {
flex: 0.1;
max-width: 50px;
border: 1px solid #777;
color: green;
}
.log-controls {
display: grid;
grid-template-columns: 1fr auto;
gap: 0.5em;
margin-bottom: 0.5em;
}
.log-controls input {
width: 100%;
padding: 4px;
}
.log-controls input:focus {
outline: none;
}
.log-stream {
flex: 1;
padding: 1em; padding: 1em;
background: #f4f4f4; background: #f4f4f4;
overflow-y: auto; overflow-y: auto;
white-space: pre-wrap; /* Ensures line wrapping */ white-space: pre-wrap;
word-wrap: break-word; /* Ensures long words wrap */ word-wrap: break-word;
min-height: 0;
} }
.regex-error { .regex-error {
background-color: #ff0000 !important; background-color: #ff0000 !important;
} }
/* Make headers clickable and show pointer cursor */
h2 {
cursor: pointer;
user-select: none;
margin: 0 0 0.5em 0;
padding: 0.5em;
}
h2:hover {
background-color: rgba(0, 0, 0, 0.05);
}
/* Dark mode styles */ /* Dark mode styles */
@media (prefers-color-scheme: dark) { @media (prefers-color-scheme: dark) {
body { body {
@@ -45,101 +78,182 @@
color: #fff; color: #fff;
} }
#log-stream { .log-stream {
background: #444; background: #444;
color: #fff; color: #fff;
} }
#log-controls input { .log-controls input {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
#log-controls button { .log-controls button {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
h2:hover {
background-color: rgba(255, 255, 255, 0.1);
}
}
/* Hide content when minimized */
.log-column.minimized .log-controls,
.log-column.minimized .log-stream {
display: none;
}
.log-column.minimized h2 {
writing-mode: vertical-rl;
text-orientation: mixed;
transform: rotate(180deg);
white-space: nowrap;
margin: auto;
} }
</style> </style>
</head> </head>
<body> <body>
<pre id="log-stream">Waiting for logs...</pre> <div class="log-container">
<div id="log-controls"> <div class="log-column">
<input type="text" id="filter-input" placeholder="regex filter"> <h2>Proxy Logs</h2>
<button id="clear-button">clear</button> <div class="log-controls">
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
<button id="proxy-clear-button">clear</button>
</div>
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
</div>
<div class="log-column minimized">
<h2>Upstream Logs</h2>
<div class="log-controls">
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
<button id="upstream-clear-button">clear</button>
</div>
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
</div>
</div> </div>
<script> <script>
const logStream = document.getElementById('log-stream'); class LogStream {
const filterInput = document.getElementById('filter-input'); constructor(streamElement, filterInput, clearButton, endpoint) {
var logData = ""; this.streamElement = streamElement;
let regexFilter = null; this.filterInput = filterInput;
this.clearButton = clearButton;
this.endpoint = endpoint;
this.logData = "";
this.regexFilter = null;
this.eventSource = null;
function setupEventSource() { this.initialize();
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
eventSource.onmessage = function(event) {
logData += event.data;
render()
};
eventSource.onerror = function(err) {
logData = "EventSource failed: " + err.message;
};
} else {
logData = "SSE Not supported by this browser."
} }
}
// poor-ai's react ¯\_(ツ)_/¯ initialize() {
function render() { this.filterInput.addEventListener('input', () => this.updateFilter());
if (regexFilter) { this.clearButton.addEventListener('click', () => {
const lines = logData.split('\n'); this.filterInput.value = "";
const filteredLines = lines.filter(line => { this.regexFilter = null;
return regexFilter === null || regexFilter.test(line); this.render();
}); });
this.setupEventSource();
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
} }
logStream.scrollTop = logStream.scrollHeight; setupEventSource() {
} if (typeof(EventSource) === "undefined") {
this.logData = "SSE Not supported by this browser.";
this.render();
return;
}
const connect = () => {
this.eventSource = new EventSource(this.endpoint);
this.eventSource.onmessage = (event) => {
this.logData += event.data;
this.logData = this.logData.slice(-1024 * 100);
this.render();
};
this.eventSource.onerror = (err) => {
// Close the current connection
this.eventSource.close();
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
this.render();
// Attempt to reconnect after 5 seconds
setTimeout(() => {
this.logData += "Attempting to reconnect...\n";
this.render();
connect();
}, 5000);
};
};
// Initial connection
connect();
}
render() {
let content = this.logData;
if (this.regexFilter) {
const lines = content.split('\n');
const filteredLines = lines.filter(line => this.regexFilter.test(line));
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
}
this.streamElement.textContent = content;
this.streamElement.scrollTop = this.streamElement.scrollHeight;
}
updateFilter() {
const pattern = this.filterInput.value.trim();
this.filterInput.classList.remove('regex-error');
if (!pattern) {
this.regexFilter = null;
this.render();
return;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try { try {
regexFilter = new RegExp(pattern); this.regexFilter = new RegExp(pattern);
} catch (e) { } catch (e) {
console.error("Invalid regex pattern:", e); console.error("Invalid regex pattern:", e);
regexFilter = null; this.regexFilter = null;
filterInput.classList.add('regex-error'); this.filterInput.classList.add('regex-error');
return return;
} }
} else {
regexFilter = null;
}
render(); this.render();
}
} }
filterInput.addEventListener('input', updateFilter); // Initialize both log streams
document.getElementById('clear-button').addEventListener('click', () => { document.addEventListener('DOMContentLoaded', () => {
filterInput.value = ""; new LogStream(
regexFilter = null; document.getElementById('proxy-log-stream'),
render(); document.getElementById('proxy-filter-input'),
document.getElementById('proxy-clear-button'),
"/logs/streamSSE/proxy"
);
new LogStream(
document.getElementById('upstream-log-stream'),
document.getElementById('upstream-filter-input'),
document.getElementById('upstream-clear-button'),
"/logs/streamSSE/upstream"
);
// Initialize clickable headers
document.querySelectorAll('h2').forEach(header => {
header.addEventListener('click', () => {
const column = header.closest('.log-column');
column.classList.toggle('minimized');
});
});
}); });
setupEventSource();
updateFilter();
</script> </script>
</body> </body>
</html> </html>
+90
View File
@@ -2,11 +2,21 @@ package proxy
import ( import (
"container/ring" "container/ring"
"fmt"
"io" "io"
"os" "os"
"sync" "sync"
) )
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct { type LogMonitor struct {
clients map[chan []byte]bool clients map[chan []byte]bool
mu sync.RWMutex mu sync.RWMutex
@@ -15,6 +25,10 @@ type LogMonitor struct {
// typically this can be os.Stdout // typically this can be os.Stdout
stdout io.Writer stdout io.Writer
// logging levels
level LogLevel
prefix string
} }
func NewLogMonitor() *LogMonitor { func NewLogMonitor() *LogMonitor {
@@ -26,6 +40,8 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool), clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout, stdout: stdout,
level: LevelInfo,
prefix: "",
} }
} }
@@ -94,3 +110,77 @@ func (w *LogMonitor) broadcast(msg []byte) {
} }
} }
} }
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+83 -29
View File
@@ -30,10 +30,15 @@ const (
) )
type Process struct { type Process struct {
ID string ID string
config ModelConfig config ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
logMonitor *LogMonitor
// for p.cmd.Wait() select { ... }
cmdWaitChan chan error
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int healthCheckTimeout int
healthCheckLoopInterval time.Duration healthCheckLoopInterval time.Duration
@@ -53,13 +58,15 @@ type Process struct {
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
logMonitor: logMonitor, cmdWaitChan: make(chan error, 1),
processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout, healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */ healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped, state: StateStopped,
@@ -68,6 +75,11 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
} }
} }
// LogMonitor returns the log monitor associated with the process.
func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// custom error types for swapping state // custom error types for swapping state
var ( var (
ErrExpectedStateMismatch = errors.New("expected state mismatch") ErrExpectedStateMismatch = errors.New("expected state mismatch")
@@ -81,14 +93,17 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
defer p.stateMutex.Unlock() defer p.stateMutex.Unlock()
if p.state != expectedState { if p.state != expectedState {
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
return p.state, ErrExpectedStateMismatch return p.state, ErrExpectedStateMismatch
} }
if !isValidTransition(p.state, newState) { if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
return p.state, ErrInvalidStateTransition return p.state, ErrInvalidStateTransition
} }
p.state = newState p.state = newState
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
return p.state, nil return p.state, nil
} }
@@ -152,8 +167,8 @@ func (p *Process) start() error {
defer p.waitStarting.Done() defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...) p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.logMonitor p.cmd.Stdout = p.processLogger
p.cmd.Stderr = p.logMonitor p.cmd.Stderr = p.processLogger
p.cmd.Env = p.config.Env p.cmd.Env = p.config.Env
err = p.cmd.Start() err = p.cmd.Start()
@@ -169,6 +184,13 @@ func (p *Process) start() error {
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
// Capture the exit error for later signaling
go func() {
exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
p.cmdWaitChan <- exitErr
}()
// One of three things can happen at this stage: // One of three things can happen at this stage:
// 1. The command exits unexpectedly // 1. The command exits unexpectedly
// 2. The health check fails // 2. The health check fails
@@ -212,17 +234,34 @@ func (p *Process) start() error {
} }
case <-p.shutdownCtx.Done(): case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown") return errors.New("health check interrupted due to shutdown")
case exitErr := <-p.cmdWaitChan:
if exitErr != nil {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
} else {
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
}
} else {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
} else {
return fmt.Errorf("upstream command exited prematurely but successfully")
}
}
default: default:
if err := p.checkHealthEndpoint(healthURL); err == nil { if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
cancelHealthCheck() cancelHealthCheck()
break loop break loop
} else { } else {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline() endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime) ttl := time.Until(endTime)
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds()) p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds())
} else { } else {
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err) p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err)
} }
} }
} }
@@ -246,7 +285,7 @@ func (p *Process) start() error {
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter) p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return return
} }
@@ -262,12 +301,17 @@ func (p *Process) start() error {
} }
func (p *Process) Stop() { func (p *Process) Stop() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
// calling Stop() when state is invalid is a no-op // calling Stop() when state is invalid is a no-op
if curState, err := p.swapState(StateReady, StateStopping); err != nil { if curState, err := p.swapState(StateReady, StateStopping); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState) p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
return return
} }
@@ -275,7 +319,7 @@ func (p *Process) Stop() {
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil { if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState) p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
} }
} }
@@ -291,49 +335,51 @@ func (p *Process) Shutdown() {
// stopCommand will send a SIGTERM to the process and wait for it to exit. // stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL. // If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) { func (p *Process) stopCommand(sigtermTTL time.Duration) {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL) sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout() defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID) p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID)
return return
} }
if err := p.terminateProcess(); err != nil { if err := p.terminateProcess(); err != nil {
fmt.Fprintf(p.logMonitor, "!!! failed to gracefully terminate process [%s]: %v\n", p.ID, err) p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err)
} }
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID) p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID)
p.cmd.Process.Kill() p.cmd.Process.Kill()
case err := <-sigtermNormal: case err := <-p.cmdWaitChan:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that ithe cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now.
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok { if errno, ok := err.(syscall.Errno); ok {
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno) p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok { } else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") { if strings.Contains(exitError.String(), "signal: terminated") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID) p.proxyLogger.Infof("<%s> Process stopped OK", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") { } else if strings.Contains(exitError.String(), "signal: interrupt") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID) p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID)
} else { } else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode()) p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
} }
} else { } else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err) p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err)
} }
} }
} }
} }
func (p *Process) checkHealthEndpoint(healthURL string) error { func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{ client := &http.Client{
Timeout: 500 * time.Millisecond, Timeout: 500 * time.Millisecond,
} }
@@ -358,6 +404,8 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
} }
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
requestBeginTime := time.Now()
var startDuration time.Duration
// prevent new requests from being made while stopping or irrecoverable // prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState() currentState := p.CurrentState()
@@ -374,11 +422,13 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// start the process on demand // start the process on demand
if p.CurrentState() != StateReady { if p.CurrentState() != StateReady {
beginStartTime := time.Now()
if err := p.start(); err != nil { if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err) errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway) http.Error(w, errstr, http.StatusBadGateway)
return return
} }
startDuration = time.Since(beginStartTime)
} }
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
@@ -422,4 +472,8 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
} }
+43 -14
View File
@@ -2,7 +2,6 @@ package proxy
import ( import (
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -13,13 +12,26 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var (
debugLogger = NewLogMonitorWriter(os.Stdout)
)
func init() {
// flip to help with debugging tests
if false {
debugLogger.SetLogLevel(LevelDebug)
} else {
debugLogger.SetLogLevel(LevelError)
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
// Create a process // Create a process
process := NewProcess("test-process", 5, config, logMonitor) process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
defer process.Stop() defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
@@ -52,11 +64,10 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start() // are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) { func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, logMonitor) process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
defer process.Stop() defer process.Stop()
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -84,7 +95,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }
process := NewProcess("broken", 1, config, NewLogMonitor()) process := NewProcess("broken", 1, config, debugLogger, debugLogger)
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -109,7 +120,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard)) process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
defer process.Stop() defer process.Stop()
// this should take 4 seconds // this should take 4 seconds
@@ -151,7 +162,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter) assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout)) process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
defer process.Stop() defer process.Stop()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@@ -178,7 +189,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345" expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout)) process := NewProcess("t", 10, config, debugLogger, debugLogger)
defer process.Stop() defer process.Stop()
results := map[string]string{ results := map[string]string{
@@ -255,9 +266,8 @@ func TestProcess_SwapState(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
p := &Process{ p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
state: test.currentState, p.state = test.currentState
}
resultState, err := p.swapState(test.expectedState, test.newState) resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil { if err != nil && test.expectedError == nil {
@@ -282,7 +292,6 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test") t.Skip("skipping long shutdown test")
} }
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong // make a config where the healthcheck will always fail because port is wrong
@@ -290,7 +299,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test" config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30 healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor) process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
// make it a lot faster // make it a lot faster
process.healthCheckLoopInterval = time.Second process.healthCheckLoopInterval = time.Second
@@ -311,3 +320,23 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
assert.ErrorContains(t, err, "health check interrupted due to shutdown") assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState()) assert.Equal(t, StateShutdown, process.CurrentState())
} }
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping Exit Interrupts Health Check test")
}
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = time.Second // make it faster
err := process.start()
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed)
}
+113
View File
@@ -0,0 +1,113 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sync"
)
type ProcessGroup struct {
sync.Mutex
config Config
id string
swap bool
exclusive bool
persistent bool
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// map of current processes
processes map[string]*Process
lastUsedProcess string
}
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
}
pg := &ProcessGroup{
id: id,
config: config,
swap: groupConfig.Swap,
exclusive: groupConfig.Exclusive,
persistent: groupConfig.Persistent,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
processes: make(map[string]*Process),
}
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
return pg
}
// ProxyRequest proxies a request to the specified model
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
if !pg.HasMember(modelID) {
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
}
if pg.swap {
pg.Lock()
if pg.lastUsedProcess != modelID {
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
}
pg.lastUsedProcess = modelID
}
pg.Unlock()
}
pg.processes[modelID].ProxyRequest(writer, request)
return nil
}
func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) StopProcesses() {
pg.Lock()
defer pg.Unlock()
pg.stopProcesses()
}
// stopProcesses stops all processes in the group
func (pg *ProcessGroup) stopProcesses() {
if len(pg.processes) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait()
}
func (pg *ProcessGroup) Shutdown() {
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
+96
View File
@@ -0,0 +1,96 @@
package proxy
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model2"},
},
"G2": {
Swap: false,
Exclusive: true,
Members: []string{"model3", "model4"},
},
},
})
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}
func TestProcessGroup_HasMember(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model1"))
assert.True(t, pg.HasMember("model2"))
assert.False(t, pg.HasMember("model3"))
}
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses()
tests := []string{"model1", "model2"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
// make sure only one process is in the running state
count := 0
for _, process := range pg.processes {
if process.CurrentState() == StateReady {
count++
}
}
assert.Equal(t, 1, count)
})
}
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses()
tests := []string{"model3", "model4"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
})
}
// make sure all the processes are running
for _, process := range pg.processes {
assert.Equal(t, StateReady, process.CurrentState())
}
}
+169 -190
View File
@@ -7,6 +7,7 @@ import (
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -25,61 +26,97 @@ const (
type ProxyManager struct { type ProxyManager struct {
sync.Mutex sync.Mutex
config *Config config Config
currentProcesses map[string]*Process ginEngine *gin.Engine
logMonitor *LogMonitor
ginEngine *gin.Engine // logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
processGroups map[string]*ProcessGroup
} }
func New(config *Config) *ProxyManager { func New(config Config) *ProxyManager {
pm := &ProxyManager{ // set up loggers
config: config, stdoutLogger := NewLogMonitorWriter(os.Stdout)
currentProcesses: make(map[string]*Process), upstreamLogger := NewLogMonitorWriter(stdoutLogger)
logMonitor: NewLogMonitor(), proxyLogger := NewLogMonitorWriter(stdoutLogger)
ginEngine: gin.New(),
}
if config.LogRequests { if config.LogRequests {
pm.ginEngine.Use(func(c *gin.Context) { proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
} }
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
upstreamLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
upstreamLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{
config: config,
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
processGroups: make(map[string]*ProcessGroup),
}
// create the process groups
for groupID := range config.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP,
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
// see: issue: #81, #77 and #42 for CORS issues // see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint // respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
// set this for all requests
c.Header("Access-Control-Allow-Origin", "*")
if c.Request.Method == "OPTIONS" { if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// allow whatever the client requested by default // allow whatever the client requested by default
@@ -118,6 +155,8 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex) pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream) pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
@@ -169,27 +208,17 @@ func (pm *ProxyManager) StopProcesses() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
pm.stopProcesses()
}
// for internal usage
func (pm *ProxyManager) stopProcesses() {
if len(pm.currentProcesses) == 0 {
return
}
// stop Processes in parallel // stop Processes in parallel
var wg sync.WaitGroup var wg sync.WaitGroup
for _, process := range pm.currentProcesses { for _, processGroup := range pm.processGroups {
wg.Add(1) wg.Add(1)
go func(process *Process) { go func(processGroup *ProcessGroup) {
defer wg.Done() defer wg.Done()
process.Stop() processGroup.stopProcesses()
}(process) }(processGroup)
} }
wg.Wait()
pm.currentProcesses = make(map[string]*Process) wg.Wait()
} }
// Shutdown is called to shutdown all upstream processes // Shutdown is called to shutdown all upstream processes
@@ -198,18 +227,44 @@ func (pm *ProxyManager) Shutdown() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
// shutdown process in parallel pm.proxyLogger.Debug("Shutdown() called in proxy manager")
var wg sync.WaitGroup var wg sync.WaitGroup
for _, process := range pm.currentProcesses { // Send shutdown signal to all process in groups
for _, processGroup := range pm.processGroups {
wg.Add(1) wg.Add(1)
go func(process *Process) { go func(processGroup *ProcessGroup) {
defer wg.Done() defer wg.Done()
process.Shutdown() processGroup.Shutdown()
}(process) }(processGroup)
} }
wg.Wait() wg.Wait()
} }
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
}
if processGroup.exclusive {
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses()
}
}
}
return processGroup, realModelName, nil
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) { func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{} data := []interface{}{}
for id, modelConfig := range pm.config.Models { for id, modelConfig := range pm.config.Models {
@@ -239,78 +294,6 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
} }
} }
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock()
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// check if model is part of the profile
if profileName != "" {
found := false
for _, item := range pm.config.Profiles[profileName] {
if item == realModelName {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
}
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
// stop all running models
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
requestedModel := c.Param("model_id") requestedModel := c.Param("model_id")
@@ -319,13 +302,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return return
} }
if process, err := pm.swapModel(requestedModel); err != nil { processGroup, _, err := pm.swapProcessGroup(requestedModel)
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) if err != nil {
} else { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
// rewrite the path return
c.Request.URL.Path = c.Param("upstreamPath")
process.ProxyRequest(c.Writer, c.Request)
} }
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
} }
func (pm *ProxyManager) upstreamIndex(c *gin.Context) { func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
@@ -363,32 +348,23 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
requestedModel := gjson.GetBytes(bodyBytes, "model").String() requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" { if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
} }
process, err := pm.swapModel(requestedModel) processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
} }
// issue #69 allow custom model names to be sent to upstream // issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" { useModelName := pm.config.Models[realModelName].UseModelName
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName) if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return return
} }
} else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
} }
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -397,8 +373,11 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Del("transfer-encoding") c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request) if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} }
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
@@ -420,26 +399,25 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return return
} }
// Swap to the requested model processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
process, err := pm.swapModel(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
} }
// Get profile name and model name from the requested model
profileName, modelName := splitRequestedModel(requestedModel)
// Copy all form values // Copy all form values
for key, values := range c.Request.MultipartForm.Value { for key, values := range c.Request.MultipartForm.Value {
for _, value := range values { for _, value := range values {
fieldValue := value fieldValue := value
// If this is the model field and we have a profile, use just the model name // If this is the model field and we have a profile, use just the model name
if key == "model" { if key == "model" {
if process.config.UseModelName != "" { // # issue #69 allow custom model names to be sent to upstream
fieldValue = process.config.UseModelName useModelName := pm.config.Models[realModelName].UseModelName
} else if profileName != "" {
fieldValue = modelName if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
} }
} }
field, err := multipartWriter.CreateFormField(key) field, err := multipartWriter.CreateFormField(key)
@@ -501,7 +479,11 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// Use the modified request for proxying // Use the modified request for proxying
process.ProxyRequest(c.Writer, modifiedReq) if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} }
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
@@ -523,14 +505,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json") context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response. runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, process := range pm.currentProcesses { for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
// Append the process ID and State (multiple entries if profiles are being used). if process.CurrentState() == StateReady {
runningProcesses = append(runningProcesses, gin.H{ runningProcesses = append(runningProcesses, gin.H{
"model": process.ID, "model": process.ID,
"state": process.state, "state": process.state,
}) })
}
}
} }
// Put the results under the `running` key. // Put the results under the `running` key.
@@ -541,15 +524,11 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.JSON(http.StatusOK, response) // Always return 200 OK context.JSON(http.StatusOK, response) // Always return 200 OK
} }
func ProcessKeyName(groupName, modelName string) string { func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
return groupName + PROFILE_SPLIT_CHAR + modelName for _, group := range pm.processGroups {
} if group.HasMember(modelName) {
return group
func splitRequestedModel(requestedModel string) (string, string) { }
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
} }
return profileName, modelName return nil
} }
+37 -8
View File
@@ -9,7 +9,6 @@ import (
) )
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept") accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") { if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
@@ -28,7 +27,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
} }
} else { } else {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "text/plain")
history := pm.logMonitor.GetHistory() history := pm.muxLogger.GetHistory()
_, err := c.Writer.Write(history) _, err := c.Writer.Write(history)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
@@ -42,8 +41,14 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked") c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe() logMonitorId := c.Param("logMonitorID")
defer pm.logMonitor.Unsubscribe(ch) logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
@@ -56,7 +61,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped // Send history first if not skipped
if !skipHistory { if !skipHistory {
history := pm.logMonitor.GetHistory() history := logger.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.Writer.Write(history) c.Writer.Write(history)
flusher.Flush() flusher.Flush()
@@ -85,15 +90,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe() logMonitorId := c.Param("logMonitorID")
defer pm.logMonitor.Unsubscribe(ch) logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
// Send history first if not skipped // Send history first if not skipped
_, skipHistory := c.GetQuery("no-history") _, skipHistory := c.GetQuery("no-history")
if !skipHistory { if !skipHistory {
history := pm.logMonitor.GetHistory() history := logger.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.SSEvent("message", string(history)) c.SSEvent("message", string(history))
c.Writer.Flush() c.Writer.Flush()
@@ -111,3 +122,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
} }
} }
} }
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+183 -322
View File
@@ -16,13 +16,14 @@ import (
) )
func TestProxyManager_SwapProcessCorrectly(t *testing.T) { func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
} LogLevel: "error",
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
@@ -35,58 +36,91 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName) assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
} }
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
} }
func TestProxyManager_SwapMultiProcess(t *testing.T) { func TestProxyManager_SwapMultiProcess(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
Profiles: map[string][]string{ LogLevel: "error",
"test": {model1, model2}, Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: false,
Members: []string{"model1"},
},
"G2": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
}, },
} })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
for modelID, requestedModel := range map[string]string{ tests := []string{"model1", "model2"}
"model1": profileModel1, for _, requestedModel := range tests {
"model2": profileModel2, t.Run(requestedModel, func(t *testing.T) {
} { reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel)
})
}
// make sure there's two loaded models
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
}
// Test that a persistent group is not affected by the swapping behaviour of
// other groups.
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
Groups: map[string]GroupConfig{
// the forever group is persistent and should not be affected by model1
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model2"},
},
},
})
proxy := New(config)
defer proxy.StopProcesses()
// make requests to load all models, loading model1 should not affect model2
tests := []string{"model2", "model1"}
for _, requestedModel := range tests {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelID) assert.Contains(t, w.Body.String(), requestedModel)
} }
// make sure there's two loaded models assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
assert.Len(t, proxy.currentProcesses, 2) assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
_, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
} }
// When a request for a different model comes in ProxyManager should wait until // When a request for a different model comes in ProxyManager should wait until
@@ -96,14 +130,15 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
t.Skip("skipping slow test") t.Skip("skipping slow test")
} }
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
} LogLevel: "error",
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
@@ -146,13 +181,14 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
} }
func TestProxyManager_ListModelsHandler(t *testing.T) { func TestProxyManager_ListModelsHandler(t *testing.T) {
config := &Config{ config := Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
LogLevel: "error",
} }
proxy := New(config) proxy := New(config)
@@ -213,50 +249,6 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
assert.Empty(t, expectedModels, "not all expected models were returned") assert.Empty(t, expectedModels, "not all expected models were returned")
} }
func TestProxyManager_ProfileNonMember(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileMemberName := ProcessKeyName("test", model1)
profileNonMemberName := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1},
},
}
proxy := New(config)
defer proxy.StopProcesses()
// actual member of profile
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
// actual model, but non-member will 404
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
}
func TestProxyManager_Shutdown(t *testing.T) { func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations // make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991) model1Config := getTestSimpleResponderConfigPort("model1", 9991)
@@ -268,23 +260,27 @@ func TestProxyManager_Shutdown(t *testing.T) {
model3Config := getTestSimpleResponderConfigPort("model3", 9993) model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/" model3Config.Proxy = "http://localhost:10003/"
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2", "model3"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": model1Config, "model1": model1Config,
"model2": model2Config, "model2": model2Config,
"model3": model3Config, "model3": model3Config,
}, },
} LogLevel: "error",
Groups: map[string]GroupConfig{
"test": {
Swap: false,
Members: []string{"model1", "model2", "model3"},
},
},
})
proxy := New(config) proxy := New(config)
// Start all the processes // Start all the processes
var wg sync.WaitGroup var wg sync.WaitGroup
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} { for _, modelName := range []string{"model1", "model2", "model3"} {
wg.Add(1) wg.Add(1)
go func(modelName string) { go func(modelName string) {
defer wg.Done() defer wg.Done()
@@ -292,11 +288,10 @@ func TestProxyManager_Shutdown(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
// send a request to trigger the proxy to load // send a request to trigger the proxy to load ... this should hang waiting for start up
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code) assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
//fmt.Println(w.Code, w.Body.String())
}(modelName) }(modelName)
} }
@@ -308,64 +303,44 @@ func TestProxyManager_Shutdown(t *testing.T) {
} }
func TestProxyManager_Unload(t *testing.T) { func TestProxyManager_Unload(t *testing.T) {
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
} LogLevel: "error",
})
proxy := New(config) proxy := New(config)
proc, err := proxy.swapModel("model1") reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
assert.NoError(t, err) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
assert.NotNil(t, proc)
assert.Len(t, proxy.currentProcesses, 1)
req := httptest.NewRequest("GET", "/unload", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK") assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// issue 62, strip profile slug from model name // give it a bit of time to stop
func TestProxyManager_StripProfileSlug(t *testing.T) { <-time.After(time.Millisecond * 250)
config := &Config{ assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
} }
// Test issue #61 `Listing the current list of models and the loaded model.` // Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) { func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration // Shared configuration
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
Profiles: map[string][]string{ LogLevel: "debug",
"test": {"model1", "model2"}, })
},
}
// Define a helper struct to parse the JSON response. // Define a helper struct to parse the JSON response.
type RunningResponse struct { type RunningResponse struct {
@@ -420,235 +395,126 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Is the model loaded? // Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State) assert.Equal(t, "ready", response.Running[0].State)
}) })
t.Run("multiple models via profile", func(t *testing.T) {
// Load more than one model.
for _, model := range []string{"model1", "model2"} {
profileModel := ProcessKeyName("test", model)
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Simulate the browser call.
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
// The JSON response must be valid.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// The response should contain 2 models.
assert.Len(t, response.Running, 2)
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
}
// Iterate through the models and check their states as well.
for _, entry := range response.Running {
_, exists := expectedModels[entry.Model]
assert.True(t, exists, "unexpected model %s", entry.Model)
assert.Equal(t, "ready", entry.State)
delete(expectedModels, entry.Model)
}
// Since we deleted each model while testing for its validity we should have no more models in the response.
assert.Empty(t, expectedModels, "unexpected additional models in response")
})
} }
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
}, },
} LogLevel: "error",
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
testCases := []struct { // Create a buffer with multipart form data
name string var b bytes.Buffer
modelInput string w := multipart.NewWriter(&b)
expectModel string
}{
{
name: "With Profile Prefix",
modelInput: "test:TheExpectedModel",
expectModel: "TheExpectedModel", // Profile prefix should be stripped
},
{
name: "Without Profile Prefix",
modelInput: "TheExpectedModel",
expectModel: "TheExpectedModel", // Should remain the same
},
}
for _, tc := range testCases { // Add the model field
t.Run(tc.name, func(t *testing.T) { fw, err := w.CreateFormField("model")
// Create a buffer with multipart form data assert.NoError(t, err)
var b bytes.Buffer _, err = fw.Write([]byte("TheExpectedModel"))
w := multipart.NewWriter(&b) assert.NoError(t, err)
// Add the model field // Add a file field
fw, err := w.CreateFormField("model") fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err) assert.NoError(t, err)
_, err = fw.Write([]byte(tc.modelInput)) // Generate random content length between 10 and 20
assert.NoError(t, err) contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Add a file field // Create the request with the multipart form data
fw, err = w.CreateFormFile("file", "test.mp3") req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
assert.NoError(t, err) req.Header.Set("Content-Type", w.FormDataContentType())
// Generate random content length between 10 and 20 rec := httptest.NewRecorder()
contentLength := rand.Intn(11) + 10 // 10 to 20 proxy.HandlerFunc(rec, req)
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data // Verify the response
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) assert.Equal(t, http.StatusOK, rec.Code)
req.Header.Set("Content-Type", w.FormDataContentType()) var response map[string]string
rec := httptest.NewRecorder() err = json.Unmarshal(rec.Body.Bytes(), &response)
proxy.HandlerFunc(rec, req) assert.NoError(t, err)
assert.Equal(t, "TheExpectedModel", response["model"])
// Verify the response assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tc.expectModel, response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
})
}
}
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
} }
// Test useModelName in configuration sends overrides what is sent to upstream // Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) { func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel" upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName) modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName modelConfig.UseModelName = upstreamModelName
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": modelConfig, "model1": modelConfig,
}, },
} LogLevel: "error",
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
tests := []struct { requestedModel := "model1"
description string
requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
for _, tt := range tests { t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) { reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder()
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName) assert.Contains(t, w.Body.String(), upstreamModelName)
})
}) t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
} // Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
for _, tt := range tests { // Add the model field
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) { fw, err := w.CreateFormField("model")
// Create a buffer with multipart form data assert.NoError(t, err)
var b bytes.Buffer _, err = fw.Write([]byte(requestedModel))
w := multipart.NewWriter(&b) assert.NoError(t, err)
// Add the model field // Add a file field
fw, err := w.CreateFormField("model") fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err) assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel)) _, err = fw.Write([]byte("test"))
assert.NoError(t, err) assert.NoError(t, err)
w.Close()
// Add a file field // Create the request with the multipart form data
fw, err = w.CreateFormFile("file", "test.mp3") req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
assert.NoError(t, err) req.Header.Set("Content-Type", w.FormDataContentType())
_, err = fw.Write([]byte("test")) rec := httptest.NewRecorder()
assert.NoError(t, err) proxy.HandlerFunc(rec, req)
w.Close()
// Create the request with the multipart form data // Verify the response
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) assert.Equal(t, http.StatusOK, rec.Code)
req.Header.Set("Content-Type", w.FormDataContentType()) var response map[string]string
rec := httptest.NewRecorder() err = json.Unmarshal(rec.Body.Bytes(), &response)
proxy.HandlerFunc(rec, req) assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
// Verify the response })
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
} }
func TestProxyManager_CORSOptionsHandler(t *testing.T) { func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogRequests: true, LogLevel: "error",
} })
tests := []struct { tests := []struct {
name string name string
@@ -709,25 +575,20 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
} }
} }
func TestProxyManager_CORSHeadersInRegularRequest(t *testing.T) { func TestProxyManager_Upstream(t *testing.T) {
config := &Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogRequests: true, LogLevel: "error",
} })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
// Test that CORS headers are present in regular POST requests rec := httptest.NewRecorder()
reqBody := `{"model":"model1"}` proxy.HandlerFunc(rec, req)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) assert.Equal(t, http.StatusOK, rec.Code)
w := httptest.NewRecorder() assert.Equal(t, "model1", rec.Body.String())
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
} }