Compare commits

..

2 Commits

Author SHA1 Message Date
Benson Wong 014a2fa9a3 fix bug checking incorrect error 2025-03-20 15:26:39 -07:00
Benson Wong 5ceaef6144 add override for windows 2025-03-20 13:21:03 -07:00
29 changed files with 674 additions and 1998 deletions
+3 -3
View File
@@ -13,11 +13,11 @@ jobs:
steps: steps:
- uses: actions/stale@v9 - uses: actions/stale@v9
with: with:
days-before-issue-stale: 14 days-before-issue-stale: 30
days-before-issue-close: 14 days-before-issue-close: 14
stale-issue-label: "stale" stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity." stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale." close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
days-before-pr-stale: -1 days-before-pr-stale: -1
days-before-pr-close: -1 days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
+1 -2
View File
@@ -15,8 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
#platform: [intel, cuda, vulkan, cpu, musa] platform: [intel, cuda, vulkan, cpu, musa]
platform: [cuda, vulkan, cpu, musa]
fail-fast: false fail-fast: false
steps: steps:
- name: Checkout code - name: Checkout code
+22 -78
View File
@@ -1,7 +1,4 @@
![llama-swap header image](header2.png) ![llama-swap header image](header.jpeg)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap # llama-swap
@@ -26,7 +23,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31)) - `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58)) - `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Automatic unloading of models after timeout by setting a `ttl` - ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support - ✅ Docker and Podman support
@@ -36,7 +33,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request. When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used. In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## config.yaml ## config.yaml
@@ -67,16 +64,8 @@ models:
# Default (and minimum) is 15 seconds # Default (and minimum) is 15 seconds
healthCheckTimeout: 60 healthCheckTimeout: 60
# Valid log levels: debug, info (default), warn, error # Write HTTP logs (useful for troubleshooting), defaults to false
logLevel: info logRequests: true
# Automatic Port Values
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
# when you use ${PORT} you can omit a custom model.proxy value, as it will
# default to http://localhost:${PORT}
# override the default port (5800) for automatic port values
startPort: 10001
# define valid model values and the upstream server start # define valid model values and the upstream server start
models: models:
@@ -91,7 +80,6 @@ models:
- "CUDA_VISIBLE_DEVICES=0" - "CUDA_VISIBLE_DEVICES=0"
# where to reach the server started by cmd, make sure the ports match # where to reach the server started by cmd, make sure the ports match
# can be omitted if you use an automatic ${PORT} in cmd
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
# aliases names to use this model for # aliases names to use this model for
@@ -118,69 +106,27 @@ models:
# but they can still be requested as normal # but they can still be requested as normal
"qwen-unlisted": "qwen-unlisted":
unlisted: true unlisted: true
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0 cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker Support (v26.1.4+ required!) # Docker Support (v26.1.4+ required!)
"docker-llama": "docker-llama":
proxy: "http://127.0.0.1:${PORT}" proxy: "http://127.0.0.1:9790"
cmd: > cmd: >
docker run --name dockertest docker run --name dockertest
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models --init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# Groups provide advanced controls over model swapping behaviour. Using groups # profiles eliminates swapping by running multiple models at the same time
# some models can be kept loaded indefinitely, while others are swapped out.
# #
# Tips: # Tips:
# # - each model must be listening on a unique address and port
# - models must be defined above in the Models section # - the model name is in this format: "profile_name:model", like "coding:qwen"
# - a model can only be a member of one group # - the profile will load and unload all models in the profile at the same time
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields profiles:
# - see issue #109 for details coding:
# - "llama"
# NOTE: the example below uses model names that are not defined above for demonstration purposes - "qwen-unlisted"
groups:
# group1 is the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap controls the model swapping behaviour in within the group
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# exclusive controls how the group affects other groups
# - true: causes all other groups to unload their models when this group runs a model
# - false: does not affect other groups
exclusive: true
# members references the models defined above
members:
- "llama"
- "qwen-unlisted"
# models in this group are never unloaded
"group2":
swap: false
exclusive: false
members:
- "docker-llama"
# (not defined above, here for example)
- "modelA"
- "modelB"
"forever":
# setting persistent to true causes the group to never be affected by the swapping behaviour of
# other groups. It is a shortcut to keeping some models always loaded.
persistent: true
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
``` ```
### Use Case Examples ### Use Case Examples
@@ -270,15 +216,9 @@ Of course, CLI access is also supported:
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' curl http://host/logs'
# streams combined logs # streams logs
curl -Ns 'http://host/logs/stream' curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes # stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time' curl -Ns http://host/logs/stream | grep 'eval time'
@@ -320,4 +260,8 @@ WantedBy=multi-user.target
## Star History ## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date) <picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
+3 -3
View File
@@ -1,9 +1,9 @@
# Seconds to wait for llama.cpp to be available to serve requests # Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds # Default (and minimum): 15 seconds
healthCheckTimeout: 90 healthCheckTimeout: 15
# valid log levels: debug, info (default), warn, error # Log HTTP requests helpful for troubleshoot, defaults to False
logLevel: debug logRequests: true
models: models:
"llama": "llama":
-153
View File
@@ -1,153 +0,0 @@
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
## Here's what you you need:
- aider - [installation docs](https://aider.chat/docs/install.html)
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
- 24GB VRAM video card
## Running aider
The goal is getting this command line to work:
```sh
aider --architect \
--no-show-model-warnings \
--model openai/QwQ \
--editor-model openai/qwen-coder-32B \
--model-settings-file aider.model.settings.yml \
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1" \
```
Set `--openai-api-base` to the IP and port where your llama-swap is running.
## Create an aider model settings file
```yaml
# aider.model.settings.yml
#
# !!! important: model names must match llama-swap configuration names !!!
#
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
```
## llama-swap configuration
```yaml
# config.yaml
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
models:
"qwen-coder-32B":
proxy: "http://127.0.0.1:8999"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--cache-type-k q8_0 --cache-type-v q8_0
-ngl 99
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
"QwQ":
proxy: "http://127.0.0.1:9503"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
--min-p 0.01 --top-k 40 --top-p 0.95
-ngl 99
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
```
## Advanced, Dual GPU Configuration
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
In llama-swap's configuration file:
1. add a `profiles` section with `aider` as the profile name
2. using the `env` field to specify the GPU IDs for each model
```yaml
# config.yaml
# Add a profile for aider
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=0"
proxy: "http://127.0.0.1:8999"
cmd: /path/to/llama-server ...
"QwQ":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
cmd: /path/to/llama-server ...
```
Append the profile tag, `aider:`, to the model names in the model settings file
```yaml
# aider.model.settings.yml
- name: "openai/aider:QwQ"
weak_model_name: "openai/aider:qwen-coder-32B-aider"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
- name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
```
Run aider with:
```sh
$ aider --architect \
--no-show-model-warnings \
--model openai/aider:QwQ \
--editor-model openai/aider:qwen-coder-32B \
--config aider.conf.yml \
--model-settings-file aider.model.settings.yml
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1"
```
@@ -1,28 +0,0 @@
# this makes use of llama-swap's profile feature to
# keep the architect and editor models in VRAM on different GPUs
- name: "openai/aider:QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B"
- name: "openai/aider:qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/aider:qwen-coder-32B"
@@ -1,26 +0,0 @@
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
-49
View File
@@ -1,49 +0,0 @@
healthCheckTimeout: 300
logLevel: debug
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
env:
- "CUDA_VISIBLE_DEVICES=0"
aliases:
- coder
proxy: "http://127.0.0.1:8999"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--ctx-size-draft 16000
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
-ngl 99 -ngld 99
--draft-max 16 --draft-min 4 --draft-p-min 0.4
--cache-type-k q8_0 --cache-type-v q8_0
"QwQ":
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6
--repeat-penalty 1.1
--dry-multiplier 0.5
--min-p 0.01
--top-k 40
--top-p 0.95
-ngl 99 -ngld 99
+1 -1
View File
@@ -37,7 +37,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect golang.org/x/net v0.37.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
-2
View File
@@ -86,8 +86,6 @@ golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 351 KiB

-4
View File
@@ -34,10 +34,6 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if len(config.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}
if mode := os.Getenv("GIN_MODE"); mode != "" { if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode) gin.SetMode(mode)
} else { } else {
+4 -14
View File
@@ -33,17 +33,14 @@ func main() {
// Set up the handler function using the provided response message // Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) { r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "text/plain")
// add a wait to simulate a slow query // add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil { if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait) time.Sleep(wait)
} }
c.JSON(http.StatusOK, gin.H{ c.String(200, *responseMessage)
"responseMessage": *responseMessage,
"h_content_length": c.Request.Header.Get("Content-Length"),
})
}) })
// for issue #62 to check model name strips profile slug // for issue #62 to check model name strips profile slug
@@ -66,11 +63,8 @@ func main() {
}) })
r.POST("/v1/completions", func(c *gin.Context) { r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "text/plain")
c.JSON(http.StatusOK, gin.H{ c.String(200, *responseMessage)
"responseMessage": *responseMessage,
})
}) })
// issue #41 // issue #41
@@ -110,10 +104,6 @@ func main() {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize), "text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model, "model": model,
// expose some header values for testing
"h_content_type": c.GetHeader("Content-Type"),
"h_content_length": c.GetHeader("Content-Length"),
}) })
}) })
+6 -151
View File
@@ -2,18 +2,13 @@ package proxy
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"sort"
"strconv"
"strings" "strings"
"github.com/google/shlex" "github.com/google/shlex"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct { type ModelConfig struct {
Cmd string `yaml:"cmd"` Cmd string `yaml:"cmd"`
Proxy string `yaml:"proxy"` Proxy string `yaml:"proxy"`
@@ -29,44 +24,14 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd) return SanitizeCommand(m.Cmd)
} }
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type Config struct { type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"` LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"` Models map[string]ModelConfig `yaml:"models"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// map aliases to actual model IDs // map aliases to actual model IDs
aliases map[string]string aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
} }
func (c *Config) RealModelName(search string) (string, bool) { func (c *Config) RealModelName(search string) (string, bool) {
@@ -87,141 +52,31 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
} }
} }
func LoadConfig(path string) (Config, error) { func LoadConfig(path string) (*Config, error) {
file, err := os.Open(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return Config{}, err return nil, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
} }
var config Config var config Config
err = yaml.Unmarshal(data, &config) err = yaml.Unmarshal(data, &config)
if err != nil { if err != nil {
return Config{}, err return nil, err
} }
if config.HealthCheckTimeout < 15 { if config.HealthCheckTimeout < 15 {
config.HealthCheckTimeout = 15 config.HealthCheckTimeout = 15
} }
// set default port ranges
if config.StartPort == 0 {
// default to 5800
config.StartPort = 5800
} else if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
// Populate the aliases map // Populate the aliases map
config.aliases = make(map[string]string) config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models { for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases { for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName config.aliases[alias] = modelName
} }
} }
// iterate over the models and replace any ${PORT} with the next available port return &config, nil
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
// iterate over the sorted models
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
if strings.Contains(modelConfig.Cmd, "${PORT}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
if modelConfig.Proxy == "" {
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
} else {
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
}
nextPort++
config.Models[modelId] = modelConfig
} else if modelConfig.Proxy == "" {
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
}
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
} }
func SanitizeCommand(cmdStr string) ([]string, error) { func SanitizeCommand(cmdStr string) ([]string, error) {
+1 -186
View File
@@ -3,7 +3,6 @@ package proxy
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -36,32 +35,11 @@ models:
aliases: aliases:
- "m2" - "m2"
checkEndpoint: "/" checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "mthree"
checkEndpoint: "/"
model4:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8082"
checkEndpoint: "/"
healthCheckTimeout: 15 healthCheckTimeout: 15
profiles: profiles:
test: test:
- model1 - model1
- model2 - model2
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
forever:
exclusive: false
persistent: true
members:
- "model4"
` `
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
@@ -74,8 +52,7 @@ groups:
t.Fatalf("Failed to load config: %v", err) t.Fatalf("Failed to load config: %v", err)
} }
expected := Config{ expected := &Config{
StartPort: 5800,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -91,18 +68,6 @@ groups:
Env: nil, Env: nil,
CheckEndpoint: "/", CheckEndpoint: "/",
}, },
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: nil,
CheckEndpoint: "/",
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
},
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{ Profiles: map[string][]string{
@@ -112,25 +77,6 @@ groups:
"m1": "model1", "m1": "model1",
"model-one": "model1", "model-one": "model1",
"m2": "model2", "m2": "model2",
"mthree": "model3",
},
Groups: map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
}, },
} }
@@ -141,63 +87,6 @@ groups:
assert.Equal(t, "model1", realname) assert.Equal(t, "model1", realname)
} }
func TestConfig_GroupMemberIsUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
healthCheckTimeout: 15
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
group2:
swap: true
exclusive: false
members: ["model2"]
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// a Contains as order of the map is not guaranteed
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
}
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
aliases:
- m1
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
aliases:
- m1
- m2
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// this is a contains because it could be `model1` or `model2` depending on the order
// go decided on the order of the map
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{ config := &ModelConfig{
Cmd: `python model1.py \ Cmd: `python model1.py \
@@ -285,77 +174,3 @@ func TestConfig_SanitizeCommand(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, args) assert.Nil(t, args)
} }
func TestConfig_AutomaticPortAssignments(t *testing.T) {
t.Run("Default Port Ranges", func(t *testing.T) {
content := ``
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
})
t.Run("User specific port ranges", func(t *testing.T) {
content := `startPort: 1000`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 1000, config.StartPort)
})
t.Run("Invalid start port", func(t *testing.T) {
content := `startPort: abcd`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("start port must be greater than 1", func(t *testing.T) {
content := `startPort: -99`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("Automatic port assignments", func(t *testing.T) {
content := `
startPort: 5800
models:
model1:
cmd: svr --port ${PORT}
model2:
cmd: svr --port ${PORT}
proxy: "http://172.11.22.33:${PORT}"
model3:
cmd: svr --port 1999
proxy: "http://1.2.3.4:1999"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
})
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
content := `
models:
model1:
cmd: svr --port 111
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
})
}
-12
View File
@@ -14,7 +14,6 @@ import (
var ( var (
nextTestPort int = 12000 nextTestPort int = 12000
portMutex sync.Mutex portMutex sync.Mutex
testLogger = NewLogMonitorWriter(os.Stdout)
) )
// Check if the binary exists // Check if the binary exists
@@ -27,17 +26,6 @@ func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
switch os.Getenv("LOG_LEVEL") {
case "debug":
testLogger.SetLogLevel(LevelDebug)
case "warn":
testLogger.SetLogLevel(LevelWarn)
case "info":
testLogger.SetLogLevel(LevelInfo)
default:
testLogger.SetLogLevel(LevelWarn)
}
m.Run() m.Run()
} }
+78 -192
View File
@@ -12,65 +12,32 @@
flex-direction: column; flex-direction: column;
font-family: "Courier New", Courier, monospace; font-family: "Courier New", Courier, monospace;
} }
.log-container { #log-controls {
display: flex;
flex: 1;
gap: 0.5em;
margin: 0.5em; margin: 0.5em;
min-height: 0;
}
.log-column {
display: flex; display: flex;
flex-direction: column; align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1; flex: 1;
min-width: 0;
transition: flex 0.3s ease;
} }
.log-column.minimized { #log-controls input:focus {
flex: 0.1; outline: none; /* Ensures no outline is shown when the input is focused */
max-width: 50px;
border: 1px solid #777;
color: green;
} }
.log-controls { #log-stream {
display: grid;
grid-template-columns: 1fr auto;
gap: 0.5em;
margin-bottom: 0.5em;
}
.log-controls input {
width: 100%;
padding: 4px;
}
.log-controls input:focus {
outline: none;
}
.log-stream {
flex: 1; flex: 1;
margin: 0.5em;
padding: 1em; padding: 1em;
background: #f4f4f4; background: #f4f4f4;
overflow-y: auto; overflow-y: auto;
white-space: pre-wrap; white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; word-wrap: break-word; /* Ensures long words wrap */
min-height: 0;
} }
.regex-error { .regex-error {
background-color: #ff0000 !important; background-color: #ff0000 !important;
} }
/* Make headers clickable and show pointer cursor */
h2 {
cursor: pointer;
user-select: none;
margin: 0 0 0.5em 0;
padding: 0.5em;
}
h2:hover {
background-color: rgba(0, 0, 0, 0.05);
}
/* Dark mode styles */ /* Dark mode styles */
@media (prefers-color-scheme: dark) { @media (prefers-color-scheme: dark) {
body { body {
@@ -78,182 +45,101 @@
color: #fff; color: #fff;
} }
.log-stream { #log-stream {
background: #444; background: #444;
color: #fff; color: #fff;
} }
.log-controls input { #log-controls input {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
.log-controls button { #log-controls button {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
h2:hover {
background-color: rgba(255, 255, 255, 0.1);
}
}
/* Hide content when minimized */
.log-column.minimized .log-controls,
.log-column.minimized .log-stream {
display: none;
}
.log-column.minimized h2 {
writing-mode: vertical-rl;
text-orientation: mixed;
transform: rotate(180deg);
white-space: nowrap;
margin: auto;
} }
</style> </style>
</head> </head>
<body> <body>
<div class="log-container"> <pre id="log-stream">Waiting for logs...</pre>
<div class="log-column"> <div id="log-controls">
<h2>Proxy Logs</h2> <input type="text" id="filter-input" placeholder="regex filter">
<div class="log-controls"> <button id="clear-button">clear</button>
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
<button id="proxy-clear-button">clear</button>
</div>
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
</div>
<div class="log-column minimized">
<h2>Upstream Logs</h2>
<div class="log-controls">
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
<button id="upstream-clear-button">clear</button>
</div>
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
</div>
</div> </div>
<script> <script>
class LogStream { const logStream = document.getElementById('log-stream');
constructor(streamElement, filterInput, clearButton, endpoint) { const filterInput = document.getElementById('filter-input');
this.streamElement = streamElement; var logData = "";
this.filterInput = filterInput; let regexFilter = null;
this.clearButton = clearButton;
this.endpoint = endpoint;
this.logData = "";
this.regexFilter = null;
this.eventSource = null;
this.initialize(); function setupEventSource() {
} if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
initialize() { eventSource.onmessage = function(event) {
this.filterInput.addEventListener('input', () => this.updateFilter()); logData += event.data;
this.clearButton.addEventListener('click', () => { render()
this.filterInput.value = "";
this.regexFilter = null;
this.render();
});
this.setupEventSource();
}
setupEventSource() {
if (typeof(EventSource) === "undefined") {
this.logData = "SSE Not supported by this browser.";
this.render();
return;
}
const connect = () => {
this.eventSource = new EventSource(this.endpoint);
this.eventSource.onmessage = (event) => {
this.logData += event.data;
this.logData = this.logData.slice(-1024 * 100);
this.render();
};
this.eventSource.onerror = (err) => {
// Close the current connection
this.eventSource.close();
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
this.render();
// Attempt to reconnect after 5 seconds
setTimeout(() => {
this.logData += "Attempting to reconnect...\n";
this.render();
connect();
}, 5000);
};
}; };
// Initial connection eventSource.onerror = function(err) {
connect(); logData = "EventSource failed: " + err.message;
} };
} else {
render() { logData = "SSE Not supported by this browser."
let content = this.logData;
if (this.regexFilter) {
const lines = content.split('\n');
const filteredLines = lines.filter(line => this.regexFilter.test(line));
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
}
this.streamElement.textContent = content;
this.streamElement.scrollTop = this.streamElement.scrollHeight;
}
updateFilter() {
const pattern = this.filterInput.value.trim();
this.filterInput.classList.remove('regex-error');
if (!pattern) {
this.regexFilter = null;
this.render();
return;
}
try {
this.regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
this.regexFilter = null;
this.filterInput.classList.add('regex-error');
return;
}
this.render();
} }
} }
// Initialize both log streams // poor-ai's react ¯\_(ツ)_/¯
document.addEventListener('DOMContentLoaded', () => { function render() {
new LogStream( if (regexFilter) {
document.getElementById('proxy-log-stream'), const lines = logData.split('\n');
document.getElementById('proxy-filter-input'), const filteredLines = lines.filter(line => {
document.getElementById('proxy-clear-button'), return regexFilter === null || regexFilter.test(line);
"/logs/streamSSE/proxy"
);
new LogStream(
document.getElementById('upstream-log-stream'),
document.getElementById('upstream-filter-input'),
document.getElementById('upstream-clear-button'),
"/logs/streamSSE/upstream"
);
// Initialize clickable headers
document.querySelectorAll('h2').forEach(header => {
header.addEventListener('click', () => {
const column = header.closest('.log-column');
column.classList.toggle('minimized');
}); });
});
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
}
logStream.scrollTop = logStream.scrollHeight;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try {
regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
regexFilter = null;
filterInput.classList.add('regex-error');
return
}
} else {
regexFilter = null;
}
render();
}
filterInput.addEventListener('input', updateFilter);
document.getElementById('clear-button').addEventListener('click', () => {
filterInput.value = "";
regexFilter = null;
render();
}); });
setupEventSource();
updateFilter();
</script> </script>
</body> </body>
</html> </html>
-90
View File
@@ -2,21 +2,11 @@ package proxy
import ( import (
"container/ring" "container/ring"
"fmt"
"io" "io"
"os" "os"
"sync" "sync"
) )
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct { type LogMonitor struct {
clients map[chan []byte]bool clients map[chan []byte]bool
mu sync.RWMutex mu sync.RWMutex
@@ -25,10 +15,6 @@ type LogMonitor struct {
// typically this can be os.Stdout // typically this can be os.Stdout
stdout io.Writer stdout io.Writer
// logging levels
level LogLevel
prefix string
} }
func NewLogMonitor() *LogMonitor { func NewLogMonitor() *LogMonitor {
@@ -40,8 +26,6 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool), clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout, stdout: stdout,
level: LevelInfo,
prefix: "",
} }
} }
@@ -110,77 +94,3 @@ func (w *LogMonitor) broadcast(msg []byte) {
} }
} }
} }
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+29 -92
View File
@@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@@ -31,15 +30,10 @@ const (
) )
type Process struct { type Process struct {
ID string ID string
config ModelConfig config ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
logMonitor *LogMonitor
// for p.cmd.Wait() select { ... }
cmdWaitChan chan error
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int healthCheckTimeout int
healthCheckLoopInterval time.Duration healthCheckLoopInterval time.Duration
@@ -59,15 +53,13 @@ type Process struct {
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
cmdWaitChan: make(chan error, 1), logMonitor: logMonitor,
processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout, healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */ healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped, state: StateStopped,
@@ -76,11 +68,6 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
} }
} }
// LogMonitor returns the log monitor associated with the process.
func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// custom error types for swapping state // custom error types for swapping state
var ( var (
ErrExpectedStateMismatch = errors.New("expected state mismatch") ErrExpectedStateMismatch = errors.New("expected state mismatch")
@@ -94,17 +81,14 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
defer p.stateMutex.Unlock() defer p.stateMutex.Unlock()
if p.state != expectedState { if p.state != expectedState {
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
return p.state, ErrExpectedStateMismatch return p.state, ErrExpectedStateMismatch
} }
if !isValidTransition(p.state, newState) { if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
return p.state, ErrInvalidStateTransition return p.state, ErrInvalidStateTransition
} }
p.state = newState p.state = newState
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
return p.state, nil return p.state, nil
} }
@@ -168,8 +152,8 @@ func (p *Process) start() error {
defer p.waitStarting.Done() defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...) p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.processLogger p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.processLogger p.cmd.Stderr = p.logMonitor
p.cmd.Env = p.config.Env p.cmd.Env = p.config.Env
err = p.cmd.Start() err = p.cmd.Start()
@@ -185,13 +169,6 @@ func (p *Process) start() error {
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
// Capture the exit error for later signaling
go func() {
exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
p.cmdWaitChan <- exitErr
}()
// One of three things can happen at this stage: // One of three things can happen at this stage:
// 1. The command exits unexpectedly // 1. The command exits unexpectedly
// 2. The health check fails // 2. The health check fails
@@ -235,34 +212,17 @@ func (p *Process) start() error {
} }
case <-p.shutdownCtx.Done(): case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown") return errors.New("health check interrupted due to shutdown")
case exitErr := <-p.cmdWaitChan:
if exitErr != nil {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
} else {
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
}
} else {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
} else {
return fmt.Errorf("upstream command exited prematurely but successfully")
}
}
default: default:
if err := p.checkHealthEndpoint(healthURL); err == nil { if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
cancelHealthCheck() cancelHealthCheck()
break loop break loop
} else { } else {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline() endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime) ttl := time.Until(endTime)
p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds()) fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
} else { } else {
p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err) fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
} }
} }
} }
@@ -286,7 +246,7 @@ func (p *Process) start() error {
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter) fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return return
} }
@@ -302,17 +262,12 @@ func (p *Process) start() error {
} }
func (p *Process) Stop() { func (p *Process) Stop() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
// calling Stop() when state is invalid is a no-op // calling Stop() when state is invalid is a no-op
if curState, err := p.swapState(StateReady, StateStopping); err != nil { if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
return return
} }
@@ -320,7 +275,7 @@ func (p *Process) Stop() {
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil { if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState) fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
} }
} }
@@ -336,51 +291,47 @@ func (p *Process) Shutdown() {
// stopCommand will send a SIGTERM to the process and wait for it to exit. // stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL. // If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) { func (p *Process) stopCommand(sigtermTTL time.Duration) {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL) sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout() defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
return return
} }
if err := p.terminateProcess(); err != nil { p.cmd.Process.Signal(syscall.SIGTERM)
p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err)
}
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.cmd.Process.Kill() p.cmd.Process.Kill()
case err := <-p.cmdWaitChan: case err := <-sigtermNormal:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that ithe cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now.
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok { if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno) fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok { } else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") { if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Infof("<%s> Process stopped OK", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") { } else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
} else { } else {
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
} }
} else { } else {
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err) fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
} }
} }
} }
} }
func (p *Process) checkHealthEndpoint(healthURL string) error { func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{ client := &http.Client{
Timeout: 500 * time.Millisecond, Timeout: 500 * time.Millisecond,
} }
@@ -405,8 +356,6 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
} }
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
requestBeginTime := time.Now()
var startDuration time.Duration
// prevent new requests from being made while stopping or irrecoverable // prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState() currentState := p.CurrentState()
@@ -423,13 +372,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// start the process on demand // start the process on demand
if p.CurrentState() != StateReady { if p.CurrentState() != StateReady {
beginStartTime := time.Now()
if err := p.start(); err != nil { if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err) errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway) http.Error(w, errstr, http.StatusBadGateway)
return return
} }
startDuration = time.Since(beginStartTime)
} }
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
@@ -440,12 +387,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
req.Header = r.Header.Clone() req.Header = r.Header.Clone()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway) http.Error(w, err.Error(), http.StatusBadGateway)
@@ -479,8 +420,4 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
} }
-9
View File
@@ -1,9 +0,0 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
-14
View File
@@ -1,14 +0,0 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+14 -43
View File
@@ -2,6 +2,7 @@ package proxy
import ( import (
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -12,26 +13,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var (
debugLogger = NewLogMonitorWriter(os.Stdout)
)
func init() {
// flip to help with debugging tests
if false {
debugLogger.SetLogLevel(LevelDebug)
} else {
debugLogger.SetLogLevel(LevelError)
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
// Create a process // Create a process
process := NewProcess("test-process", 5, config, debugLogger, debugLogger) process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop() defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
@@ -64,10 +52,11 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start() // are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) { func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger) process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop() defer process.Stop()
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -95,7 +84,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }
process := NewProcess("broken", 1, config, debugLogger, debugLogger) process := NewProcess("broken", 1, config, NewLogMonitor())
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -120,7 +109,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop() defer process.Stop()
// this should take 4 seconds // this should take 4 seconds
@@ -162,7 +151,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter) assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, debugLogger, debugLogger) process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop() defer process.Stop()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@@ -189,7 +178,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345" expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, debugLogger, debugLogger) process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop() defer process.Stop()
results := map[string]string{ results := map[string]string{
@@ -266,8 +255,9 @@ func TestProcess_SwapState(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger) p := &Process{
p.state = test.currentState state: test.currentState,
}
resultState, err := p.swapState(test.expectedState, test.newState) resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil { if err != nil && test.expectedError == nil {
@@ -292,6 +282,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test") t.Skip("skipping long shutdown test")
} }
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong // make a config where the healthcheck will always fail because port is wrong
@@ -299,7 +290,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test" config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30 healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger) process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
// make it a lot faster // make it a lot faster
process.healthCheckLoopInterval = time.Second process.healthCheckLoopInterval = time.Second
@@ -320,23 +311,3 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
assert.ErrorContains(t, err, "health check interrupted due to shutdown") assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState()) assert.Equal(t, StateShutdown, process.CurrentState())
} }
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping Exit Interrupts Health Check test")
}
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = time.Second // make it faster
err := process.start()
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed)
}
-113
View File
@@ -1,113 +0,0 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sync"
)
type ProcessGroup struct {
sync.Mutex
config Config
id string
swap bool
exclusive bool
persistent bool
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// map of current processes
processes map[string]*Process
lastUsedProcess string
}
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
}
pg := &ProcessGroup{
id: id,
config: config,
swap: groupConfig.Swap,
exclusive: groupConfig.Exclusive,
persistent: groupConfig.Persistent,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
processes: make(map[string]*Process),
}
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
return pg
}
// ProxyRequest proxies a request to the specified model
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
if !pg.HasMember(modelID) {
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
}
if pg.swap {
pg.Lock()
if pg.lastUsedProcess != modelID {
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
}
pg.lastUsedProcess = modelID
}
pg.Unlock()
}
pg.processes[modelID].ProxyRequest(writer, request)
return nil
}
func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) StopProcesses() {
pg.Lock()
defer pg.Unlock()
pg.stopProcesses()
}
// stopProcesses stops all processes in the group
func (pg *ProcessGroup) stopProcesses() {
if len(pg.processes) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait()
}
func (pg *ProcessGroup) Shutdown() {
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
-96
View File
@@ -1,96 +0,0 @@
package proxy
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model2"},
},
"G2": {
Swap: false,
Exclusive: true,
Members: []string{"model3", "model4"},
},
},
})
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}
func TestProcessGroup_HasMember(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model1"))
assert.True(t, pg.HasMember("model2"))
assert.False(t, pg.HasMember("model3"))
}
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses()
tests := []string{"model1", "model2"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
// make sure only one process is in the running state
count := 0
for _, process := range pg.processes {
if process.CurrentState() == StateReady {
count++
}
}
assert.Equal(t, 1, count)
})
}
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses()
tests := []string{"model3", "model4"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
})
}
// make sure all the processes are running
for _, process := range pg.processes {
assert.Equal(t, StateReady, process.CurrentState())
}
}
+198 -195
View File
@@ -7,7 +7,6 @@ import (
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -26,111 +25,61 @@ const (
type ProxyManager struct { type ProxyManager struct {
sync.Mutex sync.Mutex
config Config config *Config
ginEngine *gin.Engine currentProcesses map[string]*Process
logMonitor *LogMonitor
// logging ginEngine *gin.Engine
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
processGroups map[string]*ProcessGroup
} }
func New(config Config) *ProxyManager { func New(config *Config) *ProxyManager {
// set up loggers pm := &ProxyManager{
stdoutLogger := NewLogMonitorWriter(os.Stdout) config: config,
upstreamLogger := NewLogMonitorWriter(stdoutLogger) currentProcesses: make(map[string]*Process),
proxyLogger := NewLogMonitorWriter(stdoutLogger) logMonitor: NewLogMonitor(),
ginEngine: gin.New(),
}
if config.LogRequests { if config.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.") pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
} }
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) { // see: https://github.com/mostlygeek/llama-swap/issues/42
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
upstreamLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
upstreamLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{
config: config,
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
processGroups: make(map[string]*ProcessGroup),
}
// create the process groups
for groupID := range config.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP,
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
// see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint // respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" { if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
// allow whatever the client requested by default c.AbortWithStatus(204)
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return return
} }
c.Next() c.Next()
@@ -155,8 +104,6 @@ func New(config Config) *ProxyManager {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex) pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream) pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
@@ -208,17 +155,27 @@ func (pm *ProxyManager) StopProcesses() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
// stop Processes in parallel pm.stopProcesses()
var wg sync.WaitGroup }
for _, processGroup := range pm.processGroups {
wg.Add(1) // for internal usage
go func(processGroup *ProcessGroup) { func (pm *ProxyManager) stopProcesses() {
defer wg.Done() if len(pm.currentProcesses) == 0 {
processGroup.stopProcesses() return
}(processGroup)
} }
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait() wg.Wait()
pm.currentProcesses = make(map[string]*Process)
} }
// Shutdown is called to shutdown all upstream processes // Shutdown is called to shutdown all upstream processes
@@ -227,44 +184,18 @@ func (pm *ProxyManager) Shutdown() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
pm.proxyLogger.Debug("Shutdown() called in proxy manager") // shutdown process in parallel
var wg sync.WaitGroup var wg sync.WaitGroup
// Send shutdown signal to all process in groups for _, process := range pm.currentProcesses {
for _, processGroup := range pm.processGroups {
wg.Add(1) wg.Add(1)
go func(processGroup *ProcessGroup) { go func(process *Process) {
defer wg.Done() defer wg.Done()
processGroup.Shutdown() process.Shutdown()
}(processGroup) }(process)
} }
wg.Wait() wg.Wait()
} }
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
}
if processGroup.exclusive {
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses()
}
}
}
return processGroup, realModelName, nil
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) { func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{} data := []interface{}{}
for id, modelConfig := range pm.config.Models { for id, modelConfig := range pm.config.Models {
@@ -294,6 +225,78 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
} }
} }
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock()
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// check if model is part of the profile
if profileName != "" {
found := false
for _, item := range pm.config.Profiles[profileName] {
if item == realModelName {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
}
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
// stop all running models
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
requestedModel := c.Param("model_id") requestedModel := c.Param("model_id")
@@ -302,15 +305,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return return
} }
processGroup, _, err := pm.swapProcessGroup(requestedModel) if process, err := pm.swapModel(requestedModel); err != nil {
if err != nil { pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) } else {
return // rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
process.ProxyRequest(c.Writer, c.Request)
} }
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
} }
func (pm *ProxyManager) upstreamIndex(c *gin.Context) { func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
@@ -348,23 +349,32 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
requestedModel := gjson.GetBytes(bodyBytes, "model").String() requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" { if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
} }
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) process, err := pm.swapModel(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return return
} }
// issue #69 allow custom model names to be sent to upstream // issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName if process.config.UseModelName != "" {
if useModelName != "" { bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return return
} }
} else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
} }
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -373,14 +383,16 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Del("transfer-encoding") c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { process.ProxyRequest(c.Writer, c.Request)
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} }
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form // Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
@@ -394,16 +406,15 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return return
} }
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) // Swap to the requested model
process, err := pm.swapModel(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return return
} }
// We need to reconstruct the multipart form in any case since the body is consumed // Get profile name and model name from the requested model
// Create a new buffer for the reconstructed request profileName, modelName := splitRequestedModel(requestedModel)
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values // Copy all form values
for key, values := range c.Request.MultipartForm.Value { for key, values := range c.Request.MultipartForm.Value {
@@ -411,13 +422,10 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
fieldValue := value fieldValue := value
// If this is the model field and we have a profile, use just the model name // If this is the model field and we have a profile, use just the model name
if key == "model" { if key == "model" {
// # issue #69 allow custom model names to be sent to upstream if process.config.UseModelName != "" {
useModelName := pm.config.Models[realModelName].UseModelName fieldValue = process.config.UseModelName
} else if profileName != "" {
if useModelName != "" { fieldValue = modelName
fieldValue = useModelName
} else {
fieldValue = requestedModel
} }
} }
field, err := multipartWriter.CreateFormField(key) field, err := multipartWriter.CreateFormField(key)
@@ -478,16 +486,8 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.Header = c.Request.Header.Clone() modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying // Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { process.ProxyRequest(c.Writer, modifiedReq)
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} }
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
@@ -509,15 +509,14 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json") context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response. runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, processGroup := range pm.processGroups { for _, process := range pm.currentProcesses {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady { // Append the process ID and State (multiple entries if profiles are being used).
runningProcesses = append(runningProcesses, gin.H{ runningProcesses = append(runningProcesses, gin.H{
"model": process.ID, "model": process.ID,
"state": process.state, "state": process.state,
}) })
}
}
} }
// Put the results under the `running` key. // Put the results under the `running` key.
@@ -528,11 +527,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.JSON(http.StatusOK, response) // Always return 200 OK context.JSON(http.StatusOK, response) // Always return 200 OK
} }
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup { func ProcessKeyName(groupName, modelName string) string {
for _, group := range pm.processGroups { return groupName + PROFILE_SPLIT_CHAR + modelName
if group.HasMember(modelName) { }
return group
} func splitRequestedModel(requestedModel string) (string, string) {
} profileName, modelName := "", requestedModel
return nil if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
return profileName, modelName
} }
+8 -37
View File
@@ -9,6 +9,7 @@ import (
) )
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept") accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") { if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
@@ -27,7 +28,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
} }
} else { } else {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "text/plain")
history := pm.muxLogger.GetHistory() history := pm.logMonitor.GetHistory()
_, err := c.Writer.Write(history) _, err := c.Writer.Write(history)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
@@ -41,14 +42,8 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked") c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID") ch := pm.logMonitor.Subscribe()
logger, err := pm.getLogger(logMonitorId) defer pm.logMonitor.Unsubscribe(ch)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
@@ -61,7 +56,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped // Send history first if not skipped
if !skipHistory { if !skipHistory {
history := logger.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.Writer.Write(history) c.Writer.Write(history)
flusher.Flush() flusher.Flush()
@@ -90,21 +85,15 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID") ch := pm.logMonitor.Subscribe()
logger, err := pm.getLogger(logMonitorId) defer pm.logMonitor.Unsubscribe(ch)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
// Send history first if not skipped // Send history first if not skipped
_, skipHistory := c.GetQuery("no-history") _, skipHistory := c.GetQuery("no-history")
if !skipHistory { if !skipHistory {
history := logger.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.SSEvent("message", string(history)) c.SSEvent("message", string(history))
c.Writer.Flush() c.Writer.Flush()
@@ -122,21 +111,3 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
} }
} }
} }
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+305 -284
View File
@@ -8,7 +8,6 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -17,14 +16,13 @@ import (
) )
func TestProxyManager_SwapProcessCorrectly(t *testing.T) { func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
@@ -37,91 +35,58 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName) assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
} }
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
} }
func TestProxyManager_SwapMultiProcess(t *testing.T) { func TestProxyManager_SwapMultiProcess(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), model1: getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), model2: getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "error", Profiles: map[string][]string{
Groups: map[string]GroupConfig{ "test": {model1, model2},
"G1": {
Swap: true,
Exclusive: false,
Members: []string{"model1"},
},
"G2": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
}, },
})
proxy := New(config)
defer proxy.StopProcesses()
tests := []string{"model1", "model2"}
for _, requestedModel := range tests {
t.Run(requestedModel, func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel)
})
} }
// make sure there's two loaded models
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
}
// Test that a persistent group is not affected by the swapping behaviour of
// other groups.
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
Groups: map[string]GroupConfig{
// the forever group is persistent and should not be affected by model1
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model2"},
},
},
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
// make requests to load all models, loading model1 should not affect model2 for modelID, requestedModel := range map[string]string{
tests := []string{"model2", "model1"} "model1": profileModel1,
for _, requestedModel := range tests { "model2": profileModel2,
} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel) assert.Contains(t, w.Body.String(), modelID)
} }
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) // make sure there's two loaded models
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) assert.Len(t, proxy.currentProcesses, 2)
_, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
} }
// When a request for a different model comes in ProxyManager should wait until // When a request for a different model comes in ProxyManager should wait until
@@ -131,15 +96,14 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
t.Skip("skipping slow test") t.Skip("skipping slow test")
} }
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
@@ -166,9 +130,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
mu.Lock() mu.Lock()
var response map[string]string results[key] = w.Body.String()
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"]
mu.Unlock() mu.Unlock()
}(key) }(key)
@@ -184,14 +146,13 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
} }
func TestProxyManager_ListModelsHandler(t *testing.T) { func TestProxyManager_ListModelsHandler(t *testing.T) {
config := Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
LogLevel: "error",
} }
proxy := New(config) proxy := New(config)
@@ -252,6 +213,50 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
assert.Empty(t, expectedModels, "not all expected models were returned") assert.Empty(t, expectedModels, "not all expected models were returned")
} }
func TestProxyManager_ProfileNonMember(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileMemberName := ProcessKeyName("test", model1)
profileNonMemberName := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1},
},
}
proxy := New(config)
defer proxy.StopProcesses()
// actual member of profile
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
// actual model, but non-member will 404
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
}
func TestProxyManager_Shutdown(t *testing.T) { func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations // make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991) model1Config := getTestSimpleResponderConfigPort("model1", 9991)
@@ -263,27 +268,23 @@ func TestProxyManager_Shutdown(t *testing.T) {
model3Config := getTestSimpleResponderConfigPort("model3", 9993) model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/" model3Config.Proxy = "http://localhost:10003/"
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2", "model3"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": model1Config, "model1": model1Config,
"model2": model2Config, "model2": model2Config,
"model3": model3Config, "model3": model3Config,
}, },
LogLevel: "error", }
Groups: map[string]GroupConfig{
"test": {
Swap: false,
Members: []string{"model1", "model2", "model3"},
},
},
})
proxy := New(config) proxy := New(config)
// Start all the processes // Start all the processes
var wg sync.WaitGroup var wg sync.WaitGroup
for _, modelName := range []string{"model1", "model2", "model3"} { for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
wg.Add(1) wg.Add(1)
go func(modelName string) { go func(modelName string) {
defer wg.Done() defer wg.Done()
@@ -291,10 +292,11 @@ func TestProxyManager_Shutdown(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
// send a request to trigger the proxy to load ... this should hang waiting for start up // send a request to trigger the proxy to load
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code) assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
//fmt.Println(w.Code, w.Body.String())
}(modelName) }(modelName)
} }
@@ -306,44 +308,64 @@ func TestProxyManager_Shutdown(t *testing.T) {
} }
func TestProxyManager_Unload(t *testing.T) { func TestProxyManager_Unload(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") proc, err := proxy.swapModel("model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) assert.NoError(t, err)
w := httptest.NewRecorder() assert.NotNil(t, proc)
proxy.HandlerFunc(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) assert.Len(t, proxy.currentProcesses, 1)
req = httptest.NewRequest("GET", "/unload", nil) req := httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK") assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// give it a bit of time to stop // issue 62, strip profile slug from model name
<-time.After(time.Millisecond * 250) func TestProxyManager_StripProfileSlug(t *testing.T) {
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
} }
// Test issue #61 `Listing the current list of models and the loaded model.` // Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) { func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration // Shared configuration
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "debug", Profiles: map[string][]string{
}) "test": {"model1", "model2"},
},
}
// Define a helper struct to parse the JSON response. // Define a helper struct to parse the JSON response.
type RunningResponse struct { type RunningResponse struct {
@@ -398,225 +420,224 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Is the model loaded? // Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State) assert.Equal(t, "ready", response.Running[0].State)
}) })
t.Run("multiple models via profile", func(t *testing.T) {
// Load more than one model.
for _, model := range []string{"model1", "model2"} {
profileModel := ProcessKeyName("test", model)
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Simulate the browser call.
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
// The JSON response must be valid.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// The response should contain 2 models.
assert.Len(t, response.Running, 2)
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
}
// Iterate through the models and check their states as well.
for _, entry := range response.Running {
_, exists := expectedModels[entry.Model]
assert.True(t, exists, "unexpected model %s", entry.Model)
assert.Equal(t, "ready", entry.State)
delete(expectedModels, entry.Model)
}
// Since we deleted each model while testing for its validity we should have no more models in the response.
assert.Empty(t, expectedModels, "unexpected additional models in response")
})
} }
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
// Create a buffer with multipart form data testCases := []struct {
var b bytes.Buffer name string
w := multipart.NewWriter(&b) modelInput string
expectModel string
}{
{
name: "With Profile Prefix",
modelInput: "test:TheExpectedModel",
expectModel: "TheExpectedModel", // Profile prefix should be stripped
},
{
name: "Without Profile Prefix",
modelInput: "TheExpectedModel",
expectModel: "TheExpectedModel", // Should remain the same
},
}
// Add the model field for _, tc := range testCases {
fw, err := w.CreateFormField("model") t.Run(tc.name, func(t *testing.T) {
assert.NoError(t, err) // Create a buffer with multipart form data
_, err = fw.Write([]byte("TheExpectedModel")) var b bytes.Buffer
assert.NoError(t, err) w := multipart.NewWriter(&b)
// Add a file field // Add the model field
fw, err = w.CreateFormFile("file", "test.mp3") fw, err := w.CreateFormField("model")
assert.NoError(t, err) assert.NoError(t, err)
// Generate random content length between 10 and 20 _, err = fw.Write([]byte(tc.modelInput))
contentLength := rand.Intn(11) + 10 // 10 to 20 assert.NoError(t, err)
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data // Add a file field
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) fw, err = w.CreateFormFile("file", "test.mp3")
req.Header.Set("Content-Type", w.FormDataContentType()) assert.NoError(t, err)
rec := httptest.NewRecorder() // Generate random content length between 10 and 20
proxy.HandlerFunc(rec, req) contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Verify the response // Create the request with the multipart form data
assert.Equal(t, http.StatusOK, rec.Code) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
var response map[string]string req.Header.Set("Content-Type", w.FormDataContentType())
err = json.Unmarshal(rec.Body.Bytes(), &response) rec := httptest.NewRecorder()
assert.NoError(t, err) proxy.HandlerFunc(rec, req)
assert.Equal(t, "TheExpectedModel", response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder // Verify the response
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"]) assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tc.expectModel, response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
})
}
} }
// Test useModelName in configuration sends overrides what is sent to upstream func TestProxyManager_SplitRequestedModel(t *testing.T) {
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": modelConfig,
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
requestedModel := "model1"
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
tests := []struct { tests := []struct {
name string name string
method string requestedModel string
requestHeaders map[string]string expectedProfile string
expectedStatus int expectedModel string
expectedHeaders map[string]string
}{ }{
{ {"no profile", "gpt-4", "", "gpt-4"},
name: "OPTIONS with no headers", {"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
method: "OPTIONS", {"only profile", "profile1:", "profile1", ""},
expectedStatus: http.StatusNoContent, {"empty model", ":gpt-4", "", "gpt-4"},
expectedHeaders: map[string]string{ {"empty profile", ":", "", ""},
"Access-Control-Allow-Origin": "*", {"no split char", "gpt-4", "", "gpt-4"},
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS", {"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
proxy := New(config) profileName, modelName := splitRequestedModel(tt.requestedModel)
defer proxy.StopProcesses() if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
} }
if modelName != tt.expectedModel {
w := httptest.NewRecorder() t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
} }
}) })
} }
} }
func TestProxyManager_Upstream(t *testing.T) { // Test useModelName in configuration sends overrides what is sent to upstream
config := AddDefaultGroupToConfig(Config{ func TestProxyManager_UseModelName(t *testing.T) {
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config) upstreamModelName := "upstreamModel"
defer proxy.StopProcesses()
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
}
func TestProxyManager_ChatContentLength(t *testing.T) { modelConfig := getTestSimpleResponderConfig(upstreamModelName)
config := AddDefaultGroupToConfig(Config{ modelConfig.UseModelName = upstreamModelName
config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Profiles: map[string][]string{
"model1": getTestSimpleResponderConfig("model1"), "test": {"model1"},
}, },
LogLevel: "error",
}) Models: map[string]ModelConfig{
"model1": modelConfig,
},
}
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") tests := []struct {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) description string
w := httptest.NewRecorder() requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
for _, tt := range tests {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
} }
-43
View File
@@ -1,43 +0,0 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
-77
View File
@@ -1,77 +0,0 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}