Compare commits

..

2 Commits

Author SHA1 Message Date
Benson Wong 014a2fa9a3 fix bug checking incorrect error 2025-03-20 15:26:39 -07:00
Benson Wong 5ceaef6144 add override for windows 2025-03-20 13:21:03 -07:00
30 changed files with 777 additions and 2432 deletions
-15
View File
@@ -1,15 +0,0 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
poem: false
review_status: true
collapse_walkthrough: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true
+3 -3
View File
@@ -13,11 +13,11 @@ jobs:
steps: steps:
- uses: actions/stale@v9 - uses: actions/stale@v9
with: with:
days-before-issue-stale: 14 days-before-issue-stale: 30
days-before-issue-close: 14 days-before-issue-close: 14
stale-issue-label: "stale" stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity." stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale." close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
days-before-pr-stale: -1 days-before-pr-stale: -1
days-before-pr-close: -1 days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
+2 -2
View File
@@ -20,10 +20,10 @@ clean:
rm -rf $(BUILD_DIR) rm -rf $(BUILD_DIR)
test: test:
go test -short -v -count=1 ./proxy go test -short -v ./proxy
test-all: test-all:
go test -v -count=1 ./proxy go test -v ./proxy
# Build OSX binary # Build OSX binary
mac: mac:
+29 -85
View File
@@ -1,7 +1,4 @@
![llama-swap header image](header2.png) ![llama-swap header image](header.jpeg)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap # llama-swap
@@ -26,7 +23,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31)) - `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58)) - `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Automatic unloading of models after timeout by setting a `ttl` - ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support - ✅ Docker and Podman support
@@ -36,7 +33,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request. When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used. In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## config.yaml ## config.yaml
@@ -67,16 +64,8 @@ models:
# Default (and minimum) is 15 seconds # Default (and minimum) is 15 seconds
healthCheckTimeout: 60 healthCheckTimeout: 60
# Valid log levels: debug, info (default), warn, error # Write HTTP logs (useful for troubleshooting), defaults to false
logLevel: info logRequests: true
# Automatic Port Values
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
# when you use ${PORT} you can omit a custom model.proxy value, as it will
# default to http://localhost:${PORT}
# override the default port (5800) for automatic port values
startPort: 10001
# define valid model values and the upstream server start # define valid model values and the upstream server start
models: models:
@@ -91,7 +80,6 @@ models:
- "CUDA_VISIBLE_DEVICES=0" - "CUDA_VISIBLE_DEVICES=0"
# where to reach the server started by cmd, make sure the ports match # where to reach the server started by cmd, make sure the ports match
# can be omitted if you use an automatic ${PORT} in cmd
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
# aliases names to use this model for # aliases names to use this model for
@@ -118,69 +106,27 @@ models:
# but they can still be requested as normal # but they can still be requested as normal
"qwen-unlisted": "qwen-unlisted":
unlisted: true unlisted: true
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0 cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker Support (v26.1.4+ required!) # Docker Support (v26.1.4+ required!)
"docker-llama": "docker-llama":
proxy: "http://127.0.0.1:${PORT}" proxy: "http://127.0.0.1:9790"
cmd: > cmd: >
docker run --name dockertest docker run --name dockertest
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models --init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# Groups provide advanced controls over model swapping behaviour. Using groups # profiles eliminates swapping by running multiple models at the same time
# some models can be kept loaded indefinitely, while others are swapped out.
# #
# Tips: # Tips:
# # - each model must be listening on a unique address and port
# - models must be defined above in the Models section # - the model name is in this format: "profile_name:model", like "coding:qwen"
# - a model can only be a member of one group # - the profile will load and unload all models in the profile at the same time
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields profiles:
# - see issue #109 for details coding:
#
# NOTE: the example below uses model names that are not defined above for demonstration purposes
groups:
# group1 is the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap controls the model swapping behaviour in within the group
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# exclusive controls how the group affects other groups
# - true: causes all other groups to unload their models when this group runs a model
# - false: does not affect other groups
exclusive: true
# members references the models defined above
members:
- "llama" - "llama"
- "qwen-unlisted" - "qwen-unlisted"
# models in this group are never unloaded
"group2":
swap: false
exclusive: false
members:
- "docker-llama"
# (not defined above, here for example)
- "modelA"
- "modelB"
"forever":
# setting persistent to true causes the group to never be affected by the swapping behaviour of
# other groups. It is a shortcut to keeping some models always loaded.
persistent: true
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
``` ```
### Use Case Examples ### Use Case Examples
@@ -189,13 +135,18 @@ groups:
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases. - [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest. - [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations. - [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
## Configuration
llama-s
</details> </details>
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap)) ## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker is the quickest way to try out llama-swap: Docker is the quickest way to try out llama-swap:
```shell ```
# use CPU inference # use CPU inference
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu $ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
@@ -231,7 +182,7 @@ Specific versions are also available and are tagged with the llama-swap, archite
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration. Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
```shell ```
$ docker run -it --rm --runtime nvidia -p 9292:8080 \ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \ -v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \ -v /path/to/custom/config.yaml:/app/config.yaml \
@@ -246,12 +197,7 @@ Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are
1. Create a configuration file, see [config.example.yaml](config.example.yaml) 1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture. 1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
1. Run the binary with `llama-swap --config path/to/config.yaml`. 1. Run the binary with `llama-swap --config path/to/config.yaml`
Available flags:
- `--config`: Path to the configuration file (default: `config.yaml`).
- `--listen`: Address and port to listen on (default: `:8080`).
- `--version`: Show version information and exit.
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
### Building from source ### Building from source
@@ -266,19 +212,13 @@ Open the `http://<host>/logs` with your browser to get a web interface with stre
Of course, CLI access is also supported: Of course, CLI access is also supported:
```shell ```
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' curl http://host/logs'
# streams combined logs # streams logs
curl -Ns 'http://host/logs/stream' curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes # stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time' curl -Ns http://host/logs/stream | grep 'eval time'
@@ -320,4 +260,8 @@ WantedBy=multi-user.target
## Star History ## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date) <picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
+21 -17
View File
@@ -1,24 +1,17 @@
# Seconds to wait for llama.cpp to be available to serve requests # Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds # Default (and minimum): 15 seconds
healthCheckTimeout: 90 healthCheckTimeout: 15
# valid log levels: debug, info (default), warn, error # Log HTTP requests helpful for troubleshoot, defaults to False
logLevel: debug logRequests: true
# creating a coding profile with models for code generation and general questions
groups:
coding:
swap: false
members:
- "qwen"
- "llama"
models: models:
"llama": "llama":
cmd: > cmd: >
models/llama-server-osx models/llama-server-osx
--port ${PORT} --port 9001
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf -m models/Llama-3.2-1B-Instruct-Q4_0.gguf
proxy: http://127.0.0.1:9001
# list of model name aliases this llama.cpp instance can serve # list of model name aliases this llama.cpp instance can serve
aliases: aliases:
@@ -31,15 +24,17 @@ models:
ttl: 5 ttl: 5
"qwen": "qwen":
cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9002
aliases: aliases:
- gpt-3.5-turbo - gpt-3.5-turbo
# Embedding example with Nomic # Embedding example with Nomic
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF # https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
"nomic": "nomic":
proxy: http://127.0.0.1:9005
cmd: > cmd: >
models/llama-server-osx --port ${PORT} models/llama-server-osx --port 9005
-m models/nomic-embed-text-v1.5.Q8_0.gguf -m models/nomic-embed-text-v1.5.Q8_0.gguf
--ctx-size 8192 --ctx-size 8192
--batch-size 8192 --batch-size 8192
@@ -51,17 +46,19 @@ models:
# Reranking example with bge-reranker # Reranking example with bge-reranker
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF # https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
"bge-reranker": "bge-reranker":
proxy: http://127.0.0.1:9006
cmd: > cmd: >
models/llama-server-osx --port ${PORT} models/llama-server-osx --port 9006
-m models/bge-reranker-v2-m3-Q4_K_M.gguf -m models/bge-reranker-v2-m3-Q4_K_M.gguf
--ctx-size 8192 --ctx-size 8192
--reranking --reranking
# Docker Support (v26.1.4+ required!) # Docker Support (v26.1.4+ required!)
"dockertest": "dockertest":
proxy: "http://127.0.0.1:9790"
cmd: > cmd: >
docker run --name dockertest docker run --name dockertest
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models --init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
@@ -70,7 +67,8 @@ models:
env: env:
- CUDA_VISIBLE_DEVICES=0,1 - CUDA_VISIBLE_DEVICES=0,1
- env1=hello - env1=hello
cmd: build/simple-responder --port ${PORT} cmd: build/simple-responder --port 8999
proxy: http://127.0.0.1:8999
unlisted: true unlisted: true
# use "none" to skip check. Caution this may cause some requests to fail # use "none" to skip check. Caution this may cause some requests to fail
@@ -86,3 +84,9 @@ models:
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9000 proxy: http://127.0.0.1:9000
unlisted: true unlisted: true
# creating a coding profile with models for code generation and general questions
profiles:
coding:
- "qwen"
- "llama"
-153
View File
@@ -1,153 +0,0 @@
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
## Here's what you you need:
- aider - [installation docs](https://aider.chat/docs/install.html)
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
- 24GB VRAM video card
## Running aider
The goal is getting this command line to work:
```sh
aider --architect \
--no-show-model-warnings \
--model openai/QwQ \
--editor-model openai/qwen-coder-32B \
--model-settings-file aider.model.settings.yml \
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1" \
```
Set `--openai-api-base` to the IP and port where your llama-swap is running.
## Create an aider model settings file
```yaml
# aider.model.settings.yml
#
# !!! important: model names must match llama-swap configuration names !!!
#
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
```
## llama-swap configuration
```yaml
# config.yaml
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
models:
"qwen-coder-32B":
proxy: "http://127.0.0.1:8999"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--cache-type-k q8_0 --cache-type-v q8_0
-ngl 99
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
"QwQ":
proxy: "http://127.0.0.1:9503"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
--min-p 0.01 --top-k 40 --top-p 0.95
-ngl 99
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
```
## Advanced, Dual GPU Configuration
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
In llama-swap's configuration file:
1. add a `profiles` section with `aider` as the profile name
2. using the `env` field to specify the GPU IDs for each model
```yaml
# config.yaml
# Add a profile for aider
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=0"
proxy: "http://127.0.0.1:8999"
cmd: /path/to/llama-server ...
"QwQ":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
cmd: /path/to/llama-server ...
```
Append the profile tag, `aider:`, to the model names in the model settings file
```yaml
# aider.model.settings.yml
- name: "openai/aider:QwQ"
weak_model_name: "openai/aider:qwen-coder-32B-aider"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
- name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
```
Run aider with:
```sh
$ aider --architect \
--no-show-model-warnings \
--model openai/aider:QwQ \
--editor-model openai/aider:qwen-coder-32B \
--config aider.conf.yml \
--model-settings-file aider.model.settings.yml
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1"
```
@@ -1,28 +0,0 @@
# this makes use of llama-swap's profile feature to
# keep the architect and editor models in VRAM on different GPUs
- name: "openai/aider:QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B"
- name: "openai/aider:qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/aider:qwen-coder-32B"
@@ -1,26 +0,0 @@
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
-49
View File
@@ -1,49 +0,0 @@
healthCheckTimeout: 300
logLevel: debug
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
env:
- "CUDA_VISIBLE_DEVICES=0"
aliases:
- coder
proxy: "http://127.0.0.1:8999"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--ctx-size-draft 16000
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
-ngl 99 -ngld 99
--draft-max 16 --draft-min 4 --draft-p-min 0.4
--cache-type-k q8_0 --cache-type-v q8_0
"QwQ":
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6
--repeat-penalty 1.1
--dry-multiplier 0.5
--min-p 0.01
--top-k 40
--top-p 0.95
-ngl 99 -ngld 99
+1 -2
View File
@@ -3,7 +3,6 @@ module github.com/mostlygeek/llama-swap
go 1.23.0 go 1.23.0
require ( require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
@@ -38,7 +37,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect golang.org/x/net v0.37.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
+18 -10
View File
@@ -9,16 +9,12 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
@@ -27,8 +23,6 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
@@ -80,18 +74,32 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 351 KiB

+11 -141
View File
@@ -1,34 +1,25 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"log"
"net/http"
"os" "os"
"os/signal" "os/signal"
"path/filepath"
"syscall" "syscall"
"time"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy" "github.com/mostlygeek/llama-swap/proxy"
) )
var ( var version string = "0"
version string = "0" var commit string = "abcd1234"
commit string = "abcd1234" var date = "unknown"
date string = "unknown"
)
func main() { func main() {
// Define a command-line flag for the port // Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name") configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port") listenStr := flag.String("listen", ":8080", "listen ip/port")
showVersion := flag.Bool("version", false, "show version of build") showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
@@ -43,10 +34,6 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if len(config.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}
if mode := os.Getenv("GIN_MODE"); mode != "" { if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode) gin.SetMode(mode)
} else { } else {
@@ -55,135 +42,18 @@ func main() {
proxyManager := proxy.New(config) proxyManager := proxy.New(config)
// Setup channels for server management
reloadChan := make(chan *proxy.ProxyManager)
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create server with initial handler
srv := &http.Server{
Addr: *listenStr,
Handler: proxyManager,
}
// Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() { go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { <-sigChan
fmt.Printf("Fatal server error: %v\n", err) fmt.Println("Shutting down llama-swap")
close(exitChan) proxyManager.Shutdown()
} os.Exit(0)
}() }()
// Handle config reloads and signals fmt.Println("llama-swap listening on " + *listenStr)
go func() { if err := proxyManager.Run(*listenStr); err != nil {
currentManager := proxyManager fmt.Printf("Server error: %v\n", err)
for { os.Exit(1)
select {
case newManager := <-reloadChan:
log.Println("Config change detected, waiting for in-flight requests to complete...")
// Stop old manager processes gracefully (this waits for in-flight requests)
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
// Now do a full shutdown to clear the process map
currentManager.Shutdown()
currentManager = newManager
srv.Handler = newManager
log.Println("Server handler updated with new config")
case sig := <-sigChan:
fmt.Printf("Received signal %v, shutting down...\n", sig)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentManager.Shutdown()
if err := srv.Shutdown(ctx); err != nil {
fmt.Printf("Server shutdown error: %v\n", err)
}
close(exitChan)
return
}
}
}()
// Start file watcher if requested
if *watchConfig {
absConfigPath, err := filepath.Abs(*configPath)
if err != nil {
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
} else {
go watchConfigFileWithReload(absConfigPath, reloadChan)
}
}
// Wait for exit signal
<-exitChan
}
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
return
}
defer watcher.Close()
err = watcher.Add(configPath)
if err != nil {
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
return
}
log.Printf("Watching config file for changes: %s", configPath)
var debounceTimer *time.Timer
debounceDuration := 2 * time.Second
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
// We only care about writes to the specific config file
if event.Name == configPath && event.Has(fsnotify.Write) {
// Reset or start the debounce timer
if debounceTimer != nil {
debounceTimer.Stop()
}
debounceTimer = time.AfterFunc(debounceDuration, func() {
log.Printf("Config file modified: %s, reloading...", event.Name)
// Try up to 3 times with exponential backoff
var newConfig proxy.Config
var err error
for retries := 0; retries < 3; retries++ {
// Load new configuration
newConfig, err = proxy.LoadConfig(configPath)
if err == nil {
break
}
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
if retries < 2 {
time.Sleep(time.Duration(1<<retries) * time.Second)
}
}
if err != nil {
log.Printf("Failed to load new config after retries: %v", err)
return
}
// Create new ProxyManager with new config
newPM := proxy.New(newConfig)
reloadChan <- newPM
log.Println("Config reloaded successfully")
})
}
case err, ok := <-watcher.Errors:
if !ok {
log.Println("File watcher error channel closed.")
return
}
log.Printf("File watcher error: %v", err)
}
} }
} }
+7 -48
View File
@@ -26,8 +26,6 @@ func main() {
silent := flag.Bool("silent", false, "disable all logging") silent := flag.Bool("silent", false, "disable all logging")
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
// Create a new Gin router // Create a new Gin router
@@ -35,17 +33,14 @@ func main() {
// Set up the handler function using the provided response message // Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) { r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "text/plain")
// add a wait to simulate a slow query // add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil { if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait) time.Sleep(wait)
} }
c.JSON(http.StatusOK, gin.H{ c.String(200, *responseMessage)
"responseMessage": *responseMessage,
"h_content_length": c.Request.Header.Get("Content-Length"),
})
}) })
// for issue #62 to check model name strips profile slug // for issue #62 to check model name strips profile slug
@@ -68,11 +63,8 @@ func main() {
}) })
r.POST("/v1/completions", func(c *gin.Context) { r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "text/plain")
c.JSON(http.StatusOK, gin.H{ c.String(200, *responseMessage)
"responseMessage": *responseMessage,
})
}) })
// issue #41 // issue #41
@@ -112,10 +104,6 @@ func main() {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize), "text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model, "model": model,
// expose some header values for testing
"h_content_type": c.GetHeader("Content-Type"),
"h_content_length": c.GetHeader("Content-Length"),
}) })
}) })
@@ -192,10 +180,6 @@ func main() {
log.SetOutput(io.Discard) log.SetOutput(io.Discard)
} }
if !*silent {
fmt.Printf("My PID: %d\n", os.Getpid())
}
go func() { go func() {
log.Printf("simple-responder listening on %s\n", address) log.Printf("simple-responder listening on %s\n", address)
// service connections // service connections
@@ -206,36 +190,11 @@ func main() {
// Wait for interrupt signal to gracefully shutdown the server with // Wait for interrupt signal to gracefully shutdown the server with
// a timeout of 5 seconds. // a timeout of 5 seconds.
sigChan := make(chan os.Signal, 1) quit := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM // kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT // kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it // kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
countSigInt := 0
runloop:
for {
signal := <-sigChan
switch signal {
case syscall.SIGINT:
countSigInt++
if countSigInt > 1 {
break runloop
} else {
log.Println("Recieved SIGINT, send another SIGINT to shutdown")
}
case syscall.SIGTERM:
if *ignoreSigTerm {
log.Println("Ignoring SIGTERM")
} else {
log.Println("Recieved SIGTERM, shutting down")
break runloop
}
default:
break runloop
}
}
log.Println("simple-responder shutting down") log.Println("simple-responder shutting down")
} }
+6 -154
View File
@@ -2,18 +2,13 @@ package proxy
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"sort"
"strconv"
"strings" "strings"
"github.com/google/shlex" "github.com/google/shlex"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct { type ModelConfig struct {
Cmd string `yaml:"cmd"` Cmd string `yaml:"cmd"`
Proxy string `yaml:"proxy"` Proxy string `yaml:"proxy"`
@@ -23,53 +18,20 @@ type ModelConfig struct {
UnloadAfter int `yaml:"ttl"` UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"` Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"` UseModelName string `yaml:"useModelName"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
} }
func (m *ModelConfig) SanitizedCommand() ([]string, error) { func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd) return SanitizeCommand(m.Cmd)
} }
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type Config struct { type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"` LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"` Models map[string]ModelConfig `yaml:"models"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// map aliases to actual model IDs // map aliases to actual model IDs
aliases map[string]string aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
} }
func (c *Config) RealModelName(search string) (string, bool) { func (c *Config) RealModelName(search string) (string, bool) {
@@ -90,141 +52,31 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
} }
} }
func LoadConfig(path string) (Config, error) { func LoadConfig(path string) (*Config, error) {
file, err := os.Open(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return Config{}, err return nil, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
} }
var config Config var config Config
err = yaml.Unmarshal(data, &config) err = yaml.Unmarshal(data, &config)
if err != nil { if err != nil {
return Config{}, err return nil, err
} }
if config.HealthCheckTimeout < 15 { if config.HealthCheckTimeout < 15 {
config.HealthCheckTimeout = 15 config.HealthCheckTimeout = 15
} }
// set default port ranges
if config.StartPort == 0 {
// default to 5800
config.StartPort = 5800
} else if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
// Populate the aliases map // Populate the aliases map
config.aliases = make(map[string]string) config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models { for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases { for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName config.aliases[alias] = modelName
} }
} }
// iterate over the models and replace any ${PORT} with the next available port return &config, nil
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
// iterate over the sorted models
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
if strings.Contains(modelConfig.Cmd, "${PORT}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
if modelConfig.Proxy == "" {
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
} else {
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
}
nextPort++
config.Models[modelId] = modelConfig
} else if modelConfig.Proxy == "" {
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
}
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
} }
func SanitizeCommand(cmdStr string) ([]string, error) { func SanitizeCommand(cmdStr string) ([]string, error) {
+1 -186
View File
@@ -3,7 +3,6 @@ package proxy
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -36,32 +35,11 @@ models:
aliases: aliases:
- "m2" - "m2"
checkEndpoint: "/" checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "mthree"
checkEndpoint: "/"
model4:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8082"
checkEndpoint: "/"
healthCheckTimeout: 15 healthCheckTimeout: 15
profiles: profiles:
test: test:
- model1 - model1
- model2 - model2
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
forever:
exclusive: false
persistent: true
members:
- "model4"
` `
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
@@ -74,8 +52,7 @@ groups:
t.Fatalf("Failed to load config: %v", err) t.Fatalf("Failed to load config: %v", err)
} }
expected := Config{ expected := &Config{
StartPort: 5800,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -91,18 +68,6 @@ groups:
Env: nil, Env: nil,
CheckEndpoint: "/", CheckEndpoint: "/",
}, },
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: nil,
CheckEndpoint: "/",
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
},
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{ Profiles: map[string][]string{
@@ -112,25 +77,6 @@ groups:
"m1": "model1", "m1": "model1",
"model-one": "model1", "model-one": "model1",
"m2": "model2", "m2": "model2",
"mthree": "model3",
},
Groups: map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
}, },
} }
@@ -141,63 +87,6 @@ groups:
assert.Equal(t, "model1", realname) assert.Equal(t, "model1", realname)
} }
func TestConfig_GroupMemberIsUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
healthCheckTimeout: 15
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
group2:
swap: true
exclusive: false
members: ["model2"]
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// a Contains as order of the map is not guaranteed
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
}
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
aliases:
- m1
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
aliases:
- m1
- m2
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// this is a contains because it could be `model1` or `model2` depending on the order
// go decided on the order of the map
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{ config := &ModelConfig{
Cmd: `python model1.py \ Cmd: `python model1.py \
@@ -285,77 +174,3 @@ func TestConfig_SanitizeCommand(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, args) assert.Nil(t, args)
} }
func TestConfig_AutomaticPortAssignments(t *testing.T) {
t.Run("Default Port Ranges", func(t *testing.T) {
content := ``
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
})
t.Run("User specific port ranges", func(t *testing.T) {
content := `startPort: 1000`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 1000, config.StartPort)
})
t.Run("Invalid start port", func(t *testing.T) {
content := `startPort: abcd`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("start port must be greater than 1", func(t *testing.T) {
content := `startPort: -99`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("Automatic port assignments", func(t *testing.T) {
content := `
startPort: 5800
models:
model1:
cmd: svr --port ${PORT}
model2:
cmd: svr --port ${PORT}
proxy: "http://172.11.22.33:${PORT}"
model3:
cmd: svr --port 1999
proxy: "http://1.2.3.4:1999"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
})
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
content := `
models:
model1:
cmd: svr --port 111
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
})
}
+2 -18
View File
@@ -14,7 +14,6 @@ import (
var ( var (
nextTestPort int = 12000 nextTestPort int = 12000
portMutex sync.Mutex portMutex sync.Mutex
testLogger = NewLogMonitorWriter(os.Stdout)
) )
// Check if the binary exists // Check if the binary exists
@@ -27,17 +26,6 @@ func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
switch os.Getenv("LOG_LEVEL") {
case "debug":
testLogger.SetLogLevel(LevelDebug)
case "warn":
testLogger.SetLogLevel(LevelWarn)
case "info":
testLogger.SetLogLevel(LevelInfo)
default:
testLogger.SetLogLevel(LevelWarn)
}
m.Run() m.Run()
} }
@@ -48,18 +36,14 @@ func getSimpleResponderPath() string {
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch)) return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
} }
func getTestPort() int { func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
portMutex.Lock() portMutex.Lock()
defer portMutex.Unlock() defer portMutex.Unlock()
port := nextTestPort port := nextTestPort
nextTestPort++ nextTestPort++
return port return getTestSimpleResponderConfigPort(expectedMessage, port)
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
} }
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig { func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
+71 -185
View File
@@ -12,65 +12,32 @@
flex-direction: column; flex-direction: column;
font-family: "Courier New", Courier, monospace; font-family: "Courier New", Courier, monospace;
} }
.log-container { #log-controls {
display: flex;
flex: 1;
gap: 0.5em;
margin: 0.5em; margin: 0.5em;
min-height: 0;
}
.log-column {
display: flex; display: flex;
flex-direction: column; align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1; flex: 1;
min-width: 0;
transition: flex 0.3s ease;
} }
.log-column.minimized { #log-controls input:focus {
flex: 0.1; outline: none; /* Ensures no outline is shown when the input is focused */
max-width: 50px;
border: 1px solid #777;
color: green;
} }
.log-controls { #log-stream {
display: grid;
grid-template-columns: 1fr auto;
gap: 0.5em;
margin-bottom: 0.5em;
}
.log-controls input {
width: 100%;
padding: 4px;
}
.log-controls input:focus {
outline: none;
}
.log-stream {
flex: 1; flex: 1;
margin: 0.5em;
padding: 1em; padding: 1em;
background: #f4f4f4; background: #f4f4f4;
overflow-y: auto; overflow-y: auto;
white-space: pre-wrap; white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; word-wrap: break-word; /* Ensures long words wrap */
min-height: 0;
} }
.regex-error { .regex-error {
background-color: #ff0000 !important; background-color: #ff0000 !important;
} }
/* Make headers clickable and show pointer cursor */
h2 {
cursor: pointer;
user-select: none;
margin: 0 0 0.5em 0;
padding: 0.5em;
}
h2:hover {
background-color: rgba(0, 0, 0, 0.05);
}
/* Dark mode styles */ /* Dark mode styles */
@media (prefers-color-scheme: dark) { @media (prefers-color-scheme: dark) {
body { body {
@@ -78,182 +45,101 @@
color: #fff; color: #fff;
} }
.log-stream { #log-stream {
background: #444; background: #444;
color: #fff; color: #fff;
} }
.log-controls input { #log-controls input {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
.log-controls button { #log-controls button {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
h2:hover {
background-color: rgba(255, 255, 255, 0.1);
}
}
/* Hide content when minimized */
.log-column.minimized .log-controls,
.log-column.minimized .log-stream {
display: none;
}
.log-column.minimized h2 {
writing-mode: vertical-rl;
text-orientation: mixed;
transform: rotate(180deg);
white-space: nowrap;
margin: auto;
} }
</style> </style>
</head> </head>
<body> <body>
<div class="log-container"> <pre id="log-stream">Waiting for logs...</pre>
<div class="log-column"> <div id="log-controls">
<h2>Proxy Logs</h2> <input type="text" id="filter-input" placeholder="regex filter">
<div class="log-controls"> <button id="clear-button">clear</button>
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
<button id="proxy-clear-button">clear</button>
</div>
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
</div>
<div class="log-column minimized">
<h2>Upstream Logs</h2>
<div class="log-controls">
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
<button id="upstream-clear-button">clear</button>
</div>
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
</div>
</div> </div>
<script> <script>
class LogStream { const logStream = document.getElementById('log-stream');
constructor(streamElement, filterInput, clearButton, endpoint) { const filterInput = document.getElementById('filter-input');
this.streamElement = streamElement; var logData = "";
this.filterInput = filterInput; let regexFilter = null;
this.clearButton = clearButton;
this.endpoint = endpoint;
this.logData = "";
this.regexFilter = null;
this.eventSource = null;
this.initialize(); function setupEventSource() {
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
eventSource.onmessage = function(event) {
logData += event.data;
render()
};
eventSource.onerror = function(err) {
logData = "EventSource failed: " + err.message;
};
} else {
logData = "SSE Not supported by this browser."
}
} }
initialize() { // poor-ai's react ¯\_(ツ)_/¯
this.filterInput.addEventListener('input', () => this.updateFilter()); function render() {
this.clearButton.addEventListener('click', () => { if (regexFilter) {
this.filterInput.value = ""; const lines = logData.split('\n');
this.regexFilter = null; const filteredLines = lines.filter(line => {
this.render(); return regexFilter === null || regexFilter.test(line);
}); });
this.setupEventSource();
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
} }
setupEventSource() { logStream.scrollTop = logStream.scrollHeight;
if (typeof(EventSource) === "undefined") {
this.logData = "SSE Not supported by this browser.";
this.render();
return;
}
const connect = () => {
this.eventSource = new EventSource(this.endpoint);
this.eventSource.onmessage = (event) => {
this.logData += event.data;
this.logData = this.logData.slice(-1024 * 100);
this.render();
};
this.eventSource.onerror = (err) => {
// Close the current connection
this.eventSource.close();
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
this.render();
// Attempt to reconnect after 5 seconds
setTimeout(() => {
this.logData += "Attempting to reconnect...\n";
this.render();
connect();
}, 5000);
};
};
// Initial connection
connect();
}
render() {
let content = this.logData;
if (this.regexFilter) {
const lines = content.split('\n');
const filteredLines = lines.filter(line => this.regexFilter.test(line));
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
}
this.streamElement.textContent = content;
this.streamElement.scrollTop = this.streamElement.scrollHeight;
}
updateFilter() {
const pattern = this.filterInput.value.trim();
this.filterInput.classList.remove('regex-error');
if (!pattern) {
this.regexFilter = null;
this.render();
return;
} }
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try { try {
this.regexFilter = new RegExp(pattern); regexFilter = new RegExp(pattern);
} catch (e) { } catch (e) {
console.error("Invalid regex pattern:", e); console.error("Invalid regex pattern:", e);
this.regexFilter = null; regexFilter = null;
this.filterInput.classList.add('regex-error'); filterInput.classList.add('regex-error');
return; return
}
} else {
regexFilter = null;
} }
this.render(); render();
}
} }
// Initialize both log streams filterInput.addEventListener('input', updateFilter);
document.addEventListener('DOMContentLoaded', () => { document.getElementById('clear-button').addEventListener('click', () => {
new LogStream( filterInput.value = "";
document.getElementById('proxy-log-stream'), regexFilter = null;
document.getElementById('proxy-filter-input'), render();
document.getElementById('proxy-clear-button'),
"/logs/streamSSE/proxy"
);
new LogStream(
document.getElementById('upstream-log-stream'),
document.getElementById('upstream-filter-input'),
document.getElementById('upstream-clear-button'),
"/logs/streamSSE/upstream"
);
// Initialize clickable headers
document.querySelectorAll('h2').forEach(header => {
header.addEventListener('click', () => {
const column = header.closest('.log-column');
column.classList.toggle('minimized');
});
});
}); });
setupEventSource();
updateFilter();
</script> </script>
</body> </body>
</html> </html>
-90
View File
@@ -2,21 +2,11 @@ package proxy
import ( import (
"container/ring" "container/ring"
"fmt"
"io" "io"
"os" "os"
"sync" "sync"
) )
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct { type LogMonitor struct {
clients map[chan []byte]bool clients map[chan []byte]bool
mu sync.RWMutex mu sync.RWMutex
@@ -25,10 +15,6 @@ type LogMonitor struct {
// typically this can be os.Stdout // typically this can be os.Stdout
stdout io.Writer stdout io.Writer
// logging levels
level LogLevel
prefix string
} }
func NewLogMonitor() *LogMonitor { func NewLogMonitor() *LogMonitor {
@@ -40,8 +26,6 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool), clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout, stdout: stdout,
level: LevelInfo,
prefix: "",
} }
} }
@@ -110,77 +94,3 @@ func (w *LogMonitor) broadcast(msg []byte) {
} }
} }
} }
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+30 -155
View File
@@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@@ -30,23 +29,11 @@ const (
StateShutdown ProcessState = ProcessState("shutdown") StateShutdown ProcessState = ProcessState("shutdown")
) )
type StopStrategy int
const (
StopImmediately StopStrategy = iota
StopWaitForInflightRequest
)
type Process struct { type Process struct {
ID string ID string
config ModelConfig config ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
logMonitor *LogMonitor
// for p.cmd.Wait() select { ... }
cmdWaitChan chan error
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int healthCheckTimeout int
healthCheckLoopInterval time.Duration healthCheckLoopInterval time.Duration
@@ -64,52 +51,23 @@ type Process struct {
// for managing shutdown state // for managing shutdown state
shutdownCtx context.Context shutdownCtx context.Context
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
// for managing concurrency limits
concurrencyLimitSemaphore chan struct{}
// stop timeout waiting for graceful shutdown
gracefulStopTimeout time.Duration
// track that this happened
upstreamWasStoppedWithKill bool
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
concurrentLimit := 10
if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit
} else {
proxyLogger.Debugf("Concurrency limit for model %s not set, defaulting to 10", ID)
}
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
cmdWaitChan: make(chan error, 1), logMonitor: logMonitor,
processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout, healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */ healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped, state: StateStopped,
shutdownCtx: ctx, shutdownCtx: ctx,
shutdownCancel: cancel, shutdownCancel: cancel,
// concurrency limit
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
// stop timeout
gracefulStopTimeout: 5 * time.Second,
upstreamWasStoppedWithKill: false,
} }
} }
// LogMonitor returns the log monitor associated with the process.
func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// custom error types for swapping state // custom error types for swapping state
var ( var (
ErrExpectedStateMismatch = errors.New("expected state mismatch") ErrExpectedStateMismatch = errors.New("expected state mismatch")
@@ -123,17 +81,14 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
defer p.stateMutex.Unlock() defer p.stateMutex.Unlock()
if p.state != expectedState { if p.state != expectedState {
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
return p.state, ErrExpectedStateMismatch return p.state, ErrExpectedStateMismatch
} }
if !isValidTransition(p.state, newState) { if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
return p.state, ErrInvalidStateTransition return p.state, ErrInvalidStateTransition
} }
p.state = newState p.state = newState
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
return p.state, nil return p.state, nil
} }
@@ -197,8 +152,8 @@ func (p *Process) start() error {
defer p.waitStarting.Done() defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...) p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.processLogger p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.processLogger p.cmd.Stderr = p.logMonitor
p.cmd.Env = p.config.Env p.cmd.Env = p.config.Env
err = p.cmd.Start() err = p.cmd.Start()
@@ -214,22 +169,6 @@ func (p *Process) start() error {
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
// Capture the exit error for later signalling
go func() {
exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
// there is a race condition when SIGKILL is used, p.cmd.Wait() returns, and then
// the code below fires, putting an error into cmdWaitChan. This code is to prevent this
if p.upstreamWasStoppedWithKill {
p.proxyLogger.Debugf("<%s> process was killed, NOT sending exitErr: %v", p.ID, exitErr)
p.upstreamWasStoppedWithKill = false
return
}
p.cmdWaitChan <- exitErr
}()
// One of three things can happen at this stage: // One of three things can happen at this stage:
// 1. The command exits unexpectedly // 1. The command exits unexpectedly
// 2. The health check fails // 2. The health check fails
@@ -273,34 +212,17 @@ func (p *Process) start() error {
} }
case <-p.shutdownCtx.Done(): case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown") return errors.New("health check interrupted due to shutdown")
case exitErr := <-p.cmdWaitChan:
if exitErr != nil {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
} else {
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
}
} else {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
} else {
return fmt.Errorf("upstream command exited prematurely but successfully")
}
}
default: default:
if err := p.checkHealthEndpoint(healthURL); err == nil { if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
cancelHealthCheck() cancelHealthCheck()
break loop break loop
} else { } else {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline() endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime) ttl := time.Until(endTime)
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds()) fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
} else { } else {
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err) fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
} }
} }
} }
@@ -324,7 +246,7 @@ func (p *Process) start() error {
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter) fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return return
} }
@@ -339,102 +261,77 @@ func (p *Process) start() error {
} }
} }
// Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() { func (p *Process) Stop() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.StopImmediately()
}
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
// calling Stop() when state is invalid is a no-op // calling Stop() when state is invalid is a no-op
if curState, err := p.swapState(StateReady, StateStopping); err != nil { if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
return return
} }
// stop the process with a graceful exit timeout // stop the process with a graceful exit timeout
p.stopCommand(p.gracefulStopTimeout) p.stopCommand(5 * time.Second)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil { if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState) fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
} }
} }
// Shutdown is called when llama-swap is shutting down. It will give a little bit // Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process // of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down. Once a process is in // is in the state of starting, it will cancel it and shut it down
// the StateShutdown state, it can not be started again.
func (p *Process) Shutdown() { func (p *Process) Shutdown() {
p.shutdownCancel() p.shutdownCancel()
p.stopCommand(p.gracefulStopTimeout) p.stopCommand(5 * time.Second)
p.state = StateShutdown p.state = StateShutdown
} }
// stopCommand will send a SIGTERM to the process and wait for it to exit. // stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL. // If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) { func (p *Process) stopCommand(sigtermTTL time.Duration) {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL) sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout() defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
return return
} }
if err := p.terminateProcess(); err != nil { p.cmd.Process.Signal(syscall.SIGTERM)
p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
}
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.upstreamWasStoppedWithKill = true p.cmd.Process.Kill()
if err := p.cmd.Process.Kill(); err != nil { case err := <-sigtermNormal:
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
}
case err := <-p.cmdWaitChan:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that the cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now.
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok { if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno) fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok { } else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") { if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") { } else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
} else { } else {
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
} }
} else { } else {
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err) fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
} }
} }
} }
} }
func (p *Process) checkHealthEndpoint(healthURL string) error { func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{ client := &http.Client{
Timeout: 500 * time.Millisecond, Timeout: 500 * time.Millisecond,
} }
@@ -459,8 +356,6 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
} }
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
requestBeginTime := time.Now()
var startDuration time.Duration
// prevent new requests from being made while stopping or irrecoverable // prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState() currentState := p.CurrentState()
@@ -469,14 +364,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
select {
case p.concurrencyLimitSemaphore <- struct{}{}:
defer func() { <-p.concurrencyLimitSemaphore }()
default:
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}
p.inFlightRequests.Add(1) p.inFlightRequests.Add(1)
defer func() { defer func() {
p.lastRequestHandled = time.Now() p.lastRequestHandled = time.Now()
@@ -485,13 +372,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// start the process on demand // start the process on demand
if p.CurrentState() != StateReady { if p.CurrentState() != StateReady {
beginStartTime := time.Now()
if err := p.start(); err != nil { if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err) errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway) http.Error(w, errstr, http.StatusBadGateway)
return return
} }
startDuration = time.Since(beginStartTime)
} }
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
@@ -502,12 +387,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
req.Header = r.Header.Clone() req.Header = r.Header.Clone()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway) http.Error(w, err.Error(), http.StatusBadGateway)
@@ -541,8 +420,4 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
} }
-9
View File
@@ -1,9 +0,0 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
-14
View File
@@ -1,14 +0,0 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+14 -146
View File
@@ -2,6 +2,7 @@ package proxy
import ( import (
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -12,26 +13,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var (
debugLogger = NewLogMonitorWriter(os.Stdout)
)
func init() {
// flip to help with debugging tests
if false {
debugLogger.SetLogLevel(LevelDebug)
} else {
debugLogger.SetLogLevel(LevelError)
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
// Create a process // Create a process
process := NewProcess("test-process", 5, config, debugLogger, debugLogger) process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop() defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
@@ -64,10 +52,11 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start() // are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) { func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger) process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop() defer process.Stop()
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -95,7 +84,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }
process := NewProcess("broken", 1, config, debugLogger, debugLogger) process := NewProcess("broken", 1, config, NewLogMonitor())
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -120,7 +109,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop() defer process.Stop()
// this should take 4 seconds // this should take 4 seconds
@@ -162,7 +151,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter) assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, debugLogger, debugLogger) process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop() defer process.Stop()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@@ -189,7 +178,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345" expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, debugLogger, debugLogger) process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop() defer process.Stop()
results := map[string]string{ results := map[string]string{
@@ -266,8 +255,9 @@ func TestProcess_SwapState(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger) p := &Process{
p.state = test.currentState state: test.currentState,
}
resultState, err := p.swapState(test.expectedState, test.newState) resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil { if err != nil && test.expectedError == nil {
@@ -292,6 +282,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test") t.Skip("skipping long shutdown test")
} }
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong // make a config where the healthcheck will always fail because port is wrong
@@ -299,7 +290,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test" config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30 healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger) process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
// make it a lot faster // make it a lot faster
process.healthCheckLoopInterval = time.Second process.healthCheckLoopInterval = time.Second
@@ -320,126 +311,3 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
assert.ErrorContains(t, err, "health check interrupted due to shutdown") assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState()) assert.Equal(t, StateShutdown, process.CurrentState())
} }
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping Exit Interrupts Health Check test")
}
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = time.Second // make it faster
err := process.start()
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed)
}
func TestProcess_ConcurrencyLimit(t *testing.T) {
if testing.Short() {
t.Skip("skipping long concurrency limit test")
}
expectedMessage := "concurrency_limit_test"
config := getTestSimpleResponderConfig(expectedMessage)
// only allow 1 concurrent request at a time
config.ConcurrencyLimit = 1
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
defer process.Stop()
// launch a goroutine first to take up the semaphore
go func() {
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req1)
assert.Equal(t, http.StatusOK, w.Code)
}()
// let the goroutine start
<-time.After(time.Millisecond * 25)
denied := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, denied)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
}
func TestProcess_StopImmediately(t *testing.T) {
expectedMessage := "test_stop_immediate"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
// the upstream command
func TestProcess_ForceStopWithKill(t *testing.T) {
expectedMessage := "test_sigkill"
binaryPath := getSimpleResponderPath()
port := getTestPort()
config := ModelConfig{
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
// to force the process to exit
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
}
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()
// reduce to make testing go faster
process.gracefulStopTimeout = time.Second
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
waitChan := make(chan struct{})
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
// StatusOK because that was already sent before the kill
assert.Equal(t, http.StatusOK, w.Code)
// unexpected EOF because the kill happened, the "1" is sent before the kill
// then the unexpected EOF is sent after the kill
assert.Equal(t, "1unexpected EOF\n", w.Body.String())
close(waitChan)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
// the request should have been interrupted by SIGKILL
<-waitChan
}
-114
View File
@@ -1,114 +0,0 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sync"
)
type ProcessGroup struct {
sync.Mutex
config Config
id string
swap bool
exclusive bool
persistent bool
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// map of current processes
processes map[string]*Process
lastUsedProcess string
}
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
}
pg := &ProcessGroup{
id: id,
config: config,
swap: groupConfig.Swap,
exclusive: groupConfig.Exclusive,
persistent: groupConfig.Persistent,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
processes: make(map[string]*Process),
}
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
return pg
}
// ProxyRequest proxies a request to the specified model
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
if !pg.HasMember(modelID) {
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
}
if pg.swap {
pg.Lock()
if pg.lastUsedProcess != modelID {
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
}
pg.lastUsedProcess = modelID
}
pg.Unlock()
}
pg.processes[modelID].ProxyRequest(writer, request)
return nil
}
func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
if len(pg.processes) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
}(process)
}
wg.Wait()
}
func (pg *ProcessGroup) Shutdown() {
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
-96
View File
@@ -1,96 +0,0 @@
package proxy
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model2"},
},
"G2": {
Swap: false,
Exclusive: true,
Members: []string{"model3", "model4"},
},
},
})
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}
func TestProcessGroup_HasMember(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model1"))
assert.True(t, pg.HasMember("model2"))
assert.False(t, pg.HasMember("model3"))
}
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
// make sure only one process is in the running state
count := 0
for _, process := range pg.processes {
if process.CurrentState() == StateReady {
count++
}
}
assert.Equal(t, 1, count)
})
}
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model3", "model4"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
})
}
// make sure all the processes are running
for _, process := range pg.processes {
assert.Equal(t, StateReady, process.CurrentState())
}
}
+176 -176
View File
@@ -7,7 +7,6 @@ import (
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -26,67 +25,21 @@ const (
type ProxyManager struct { type ProxyManager struct {
sync.Mutex sync.Mutex
config Config config *Config
currentProcesses map[string]*Process
logMonitor *LogMonitor
ginEngine *gin.Engine ginEngine *gin.Engine
// logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
processGroups map[string]*ProcessGroup
} }
func New(config Config) *ProxyManager { func New(config *Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
upstreamLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
upstreamLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{ pm := &ProxyManager{
config: config, config: config,
currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(),
ginEngine: gin.New(), ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
processGroups: make(map[string]*ProcessGroup),
} }
// create the process groups if config.LogRequests {
for groupID := range config.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
pm.setupGinEngine()
return pm
}
func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
// Start timer // Start timer
start := time.Now() start := time.Now()
@@ -105,8 +58,9 @@ func (pm *ProxyManager) setupGinEngine() {
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
bodySize := c.Writer.Size() bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v", fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP, clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method, method,
path, path,
c.Request.Proto, c.Request.Proto,
@@ -116,26 +70,16 @@ func (pm *ProxyManager) setupGinEngine() {
duration, duration,
) )
}) })
}
// see: issue: #81, #77 and #42 for CORS issues // see: https://github.com/mostlygeek/llama-swap/issues/42
// respond with permissive OPTIONS for any endpoint // respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" { if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
// allow whatever the client requested by default c.AbortWithStatus(204)
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return return
} }
c.Next() c.Next()
@@ -160,8 +104,6 @@ func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex) pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream) pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
@@ -197,77 +139,63 @@ func (pm *ProxyManager) setupGinEngine() {
// Disable console color for testing // Disable console color for testing
gin.DisableConsoleColor() gin.DisableConsoleColor()
return pm
} }
// ServeHTTP implements http.Handler interface func (pm *ProxyManager) Run(addr ...string) error {
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) { return pm.ginEngine.Run(addr...)
}
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r) pm.ginEngine.ServeHTTP(w, r)
} }
// StopProcesses acquires a lock and stops all running upstream processes. func (pm *ProxyManager) StopProcesses() {
// This is the public method safe for concurrent calls.
// Unlike Shutdown, this method only stops the processes but doesn't perform
// a complete shutdown, allowing for process replacement without full termination.
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
// stop Processes in parallel pm.stopProcesses()
var wg sync.WaitGroup
for _, processGroup := range pm.processGroups {
wg.Add(1)
go func(processGroup *ProcessGroup) {
defer wg.Done()
processGroup.StopProcesses(strategy)
}(processGroup)
}
wg.Wait()
} }
// Shutdown stops all processes managed by this ProxyManager // for internal usage
func (pm *ProxyManager) stopProcesses() {
if len(pm.currentProcesses) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait()
pm.currentProcesses = make(map[string]*Process)
}
// Shutdown is called to shutdown all upstream processes
// when llama-swap is shutting down.
func (pm *ProxyManager) Shutdown() { func (pm *ProxyManager) Shutdown() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
pm.proxyLogger.Debug("Shutdown() called in proxy manager") // shutdown process in parallel
var wg sync.WaitGroup var wg sync.WaitGroup
// Send shutdown signal to all process in groups for _, process := range pm.currentProcesses {
for _, processGroup := range pm.processGroups {
wg.Add(1) wg.Add(1)
go func(processGroup *ProcessGroup) { go func(process *Process) {
defer wg.Done() defer wg.Done()
processGroup.Shutdown() process.Shutdown()
}(processGroup) }(process)
} }
wg.Wait() wg.Wait()
} }
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
}
if processGroup.exclusive {
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses(StopWaitForInflightRequest)
}
}
}
return processGroup, realModelName, nil
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) { func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{} data := []interface{}{}
for id, modelConfig := range pm.config.Models { for id, modelConfig := range pm.config.Models {
@@ -297,6 +225,78 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
} }
} }
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock()
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// check if model is part of the profile
if profileName != "" {
found := false
for _, item := range pm.config.Profiles[profileName] {
if item == realModelName {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
}
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
// stop all running models
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
requestedModel := c.Param("model_id") requestedModel := c.Param("model_id")
@@ -305,15 +305,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return return
} }
processGroup, _, err := pm.swapProcessGroup(requestedModel) if process, err := pm.swapModel(requestedModel); err != nil {
if err != nil { pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) } else {
return
}
// rewrite the path // rewrite the path
c.Request.URL.Path = c.Param("upstreamPath") c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request) process.ProxyRequest(c.Writer, c.Request)
}
} }
func (pm *ProxyManager) upstreamIndex(c *gin.Context) { func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
@@ -351,23 +349,32 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
requestedModel := gjson.GetBytes(bodyBytes, "model").String() requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" { if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
} }
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) process, err := pm.swapModel(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return return
} }
// issue #69 allow custom model names to be sent to upstream // issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName if process.config.UseModelName != "" {
if useModelName != "" { bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return return
} }
} else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
} }
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -376,14 +383,16 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Del("transfer-encoding") c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { process.ProxyRequest(c.Writer, c.Request)
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} }
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form // Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
@@ -397,16 +406,15 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return return
} }
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) // Swap to the requested model
process, err := pm.swapModel(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return return
} }
// We need to reconstruct the multipart form in any case since the body is consumed // Get profile name and model name from the requested model
// Create a new buffer for the reconstructed request profileName, modelName := splitRequestedModel(requestedModel)
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values // Copy all form values
for key, values := range c.Request.MultipartForm.Value { for key, values := range c.Request.MultipartForm.Value {
@@ -414,13 +422,10 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
fieldValue := value fieldValue := value
// If this is the model field and we have a profile, use just the model name // If this is the model field and we have a profile, use just the model name
if key == "model" { if key == "model" {
// # issue #69 allow custom model names to be sent to upstream if process.config.UseModelName != "" {
useModelName := pm.config.Models[realModelName].UseModelName fieldValue = process.config.UseModelName
} else if profileName != "" {
if useModelName != "" { fieldValue = modelName
fieldValue = useModelName
} else {
fieldValue = requestedModel
} }
} }
field, err := multipartWriter.CreateFormField(key) field, err := multipartWriter.CreateFormField(key)
@@ -481,16 +486,8 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.Header = c.Request.Header.Clone() modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying // Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { process.ProxyRequest(c.Writer, modifiedReq)
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} }
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
@@ -504,7 +501,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
} }
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses(StopImmediately) pm.StopProcesses()
c.String(http.StatusOK, "OK") c.String(http.StatusOK, "OK")
} }
@@ -512,15 +509,14 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json") context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response. runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, processGroup := range pm.processGroups { for _, process := range pm.currentProcesses {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady { // Append the process ID and State (multiple entries if profiles are being used).
runningProcesses = append(runningProcesses, gin.H{ runningProcesses = append(runningProcesses, gin.H{
"model": process.ID, "model": process.ID,
"state": process.state, "state": process.state,
}) })
}
}
} }
// Put the results under the `running` key. // Put the results under the `running` key.
@@ -531,11 +527,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.JSON(http.StatusOK, response) // Always return 200 OK context.JSON(http.StatusOK, response) // Always return 200 OK
} }
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup { func ProcessKeyName(groupName, modelName string) string {
for _, group := range pm.processGroups { return groupName + PROFILE_SPLIT_CHAR + modelName
if group.HasMember(modelName) { }
return group
} func splitRequestedModel(requestedModel string) (string, string) {
} profileName, modelName := "", requestedModel
return nil if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
return profileName, modelName
} }
+8 -37
View File
@@ -9,6 +9,7 @@ import (
) )
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept") accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") { if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
@@ -27,7 +28,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
} }
} else { } else {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "text/plain")
history := pm.muxLogger.GetHistory() history := pm.logMonitor.GetHistory()
_, err := c.Writer.Write(history) _, err := c.Writer.Write(history)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
@@ -41,14 +42,8 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked") c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID") ch := pm.logMonitor.Subscribe()
logger, err := pm.getLogger(logMonitorId) defer pm.logMonitor.Unsubscribe(ch)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
@@ -61,7 +56,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped // Send history first if not skipped
if !skipHistory { if !skipHistory {
history := logger.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.Writer.Write(history) c.Writer.Write(history)
flusher.Flush() flusher.Flush()
@@ -90,21 +85,15 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID") ch := pm.logMonitor.Subscribe()
logger, err := pm.getLogger(logMonitorId) defer pm.logMonitor.Unsubscribe(ch)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
// Send history first if not skipped // Send history first if not skipped
_, skipHistory := c.GetQuery("no-history") _, skipHistory := c.GetQuery("no-history")
if !skipHistory { if !skipHistory {
history := logger.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.SSEvent("message", string(history)) c.SSEvent("message", string(history))
c.Writer.Flush() c.Writer.Flush()
@@ -122,21 +111,3 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
} }
} }
} }
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+266 -242
View File
@@ -8,7 +8,6 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -17,110 +16,77 @@ import (
) )
func TestProxyManager_SwapProcessCorrectly(t *testing.T) { func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses()
for _, modelName := range []string{"model1", "model2"} { for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName) assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
} }
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
} }
func TestProxyManager_SwapMultiProcess(t *testing.T) { func TestProxyManager_SwapMultiProcess(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), model1: getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), model2: getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "error", Profiles: map[string][]string{
Groups: map[string]GroupConfig{ "test": {model1, model2},
"G1": {
Swap: true,
Exclusive: false,
Members: []string{"model1"},
}, },
"G2": { }
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
},
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses()
tests := []string{"model1", "model2"} for modelID, requestedModel := range map[string]string{
for _, requestedModel := range tests { "model1": profileModel1,
t.Run(requestedModel, func(t *testing.T) { "model2": profileModel2,
} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel) assert.Contains(t, w.Body.String(), modelID)
})
} }
// make sure there's two loaded models // make sure there's two loaded models
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) assert.Len(t, proxy.currentProcesses, 2)
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) _, exists := proxy.currentProcesses[profileModel1]
} assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
// Test that a persistent group is not affected by the swapping behaviour of _, exists = proxy.currentProcesses[profileModel2]
// other groups. assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
Groups: map[string]GroupConfig{
// the forever group is persistent and should not be affected by model1
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model2"},
},
},
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// make requests to load all models, loading model1 should not affect model2
tests := []string{"model2", "model1"}
for _, requestedModel := range tests {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel)
}
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
} }
// When a request for a different model comes in ProxyManager should wait until // When a request for a different model comes in ProxyManager should wait until
@@ -130,18 +96,17 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
t.Skip("skipping slow test") t.Skip("skipping slow test")
} }
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses()
results := map[string]string{} results := map[string]string{}
@@ -157,16 +122,15 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key) t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
} }
mu.Lock() mu.Lock()
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) results[key] = w.Body.String()
results[key] = response["responseMessage"]
mu.Unlock() mu.Unlock()
}(key) }(key)
@@ -182,14 +146,13 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
} }
func TestProxyManager_ListModelsHandler(t *testing.T) { func TestProxyManager_ListModelsHandler(t *testing.T) {
config := Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
LogLevel: "error",
} }
proxy := New(config) proxy := New(config)
@@ -200,7 +163,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
// Call the listModelsHandler // Call the listModelsHandler
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
// Check the response status code // Check the response status code
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -250,6 +213,50 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
assert.Empty(t, expectedModels, "not all expected models were returned") assert.Empty(t, expectedModels, "not all expected models were returned")
} }
func TestProxyManager_ProfileNonMember(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileMemberName := ProcessKeyName("test", model1)
profileNonMemberName := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1},
},
}
proxy := New(config)
defer proxy.StopProcesses()
// actual member of profile
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
// actual model, but non-member will 404
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
}
func TestProxyManager_Shutdown(t *testing.T) { func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations // make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991) model1Config := getTestSimpleResponderConfigPort("model1", 9991)
@@ -261,27 +268,23 @@ func TestProxyManager_Shutdown(t *testing.T) {
model3Config := getTestSimpleResponderConfigPort("model3", 9993) model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/" model3Config.Proxy = "http://localhost:10003/"
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2", "model3"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": model1Config, "model1": model1Config,
"model2": model2Config, "model2": model2Config,
"model3": model3Config, "model3": model3Config,
}, },
LogLevel: "error", }
Groups: map[string]GroupConfig{
"test": {
Swap: false,
Members: []string{"model1", "model2", "model3"},
},
},
})
proxy := New(config) proxy := New(config)
// Start all the processes // Start all the processes
var wg sync.WaitGroup var wg sync.WaitGroup
for _, modelName := range []string{"model1", "model2", "model3"} { for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
wg.Add(1) wg.Add(1)
go func(modelName string) { go func(modelName string) {
defer wg.Done() defer wg.Done()
@@ -289,10 +292,11 @@ func TestProxyManager_Shutdown(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
// send a request to trigger the proxy to load ... this should hang waiting for start up // send a request to trigger the proxy to load
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code) assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
//fmt.Println(w.Code, w.Body.String())
}(modelName) }(modelName)
} }
@@ -304,43 +308,64 @@ func TestProxyManager_Shutdown(t *testing.T) {
} }
func TestProxyManager_Unload(t *testing.T) { func TestProxyManager_Unload(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") proc, err := proxy.swapModel("model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) assert.NoError(t, err)
w := httptest.NewRecorder() assert.NotNil(t, proc)
proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) assert.Len(t, proxy.currentProcesses, 1)
req = httptest.NewRequest("GET", "/unload", nil) req := httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK") assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// give it a bit of time to stop // issue 62, strip profile slug from model name
<-time.After(time.Millisecond * 250) func TestProxyManager_StripProfileSlug(t *testing.T) {
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
} }
// Test issue #61 `Listing the current list of models and the loaded model.` // Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) { func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration // Shared configuration
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "warn", Profiles: map[string][]string{
}) "test": {"model1", "model2"},
},
}
// Define a helper struct to parse the JSON response. // Define a helper struct to parse the JSON response.
type RunningResponse struct { type RunningResponse struct {
@@ -352,12 +377,12 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Create proxy once for all tests // Create proxy once for all tests
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses()
t.Run("no models loaded", func(t *testing.T) { t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil) req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -375,13 +400,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
reqBody := `{"model":"model1"}` reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint. // Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil) req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
var response RunningResponse var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
@@ -395,20 +420,82 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Is the model loaded? // Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State) assert.Equal(t, "ready", response.Running[0].State)
}) })
t.Run("multiple models via profile", func(t *testing.T) {
// Load more than one model.
for _, model := range []string{"model1", "model2"} {
profileModel := ProcessKeyName("test", model)
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Simulate the browser call.
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
// The JSON response must be valid.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// The response should contain 2 models.
assert.Len(t, response.Running, 2)
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
}
// Iterate through the models and check their states as well.
for _, entry := range response.Running {
_, exists := expectedModels[entry.Model]
assert.True(t, exists, "unexpected model %s", entry.Model)
assert.Equal(t, "ready", entry.State)
delete(expectedModels, entry.Model)
}
// Since we deleted each model while testing for its validity we should have no more models in the response.
assert.Empty(t, expectedModels, "unexpected additional models in response")
})
} }
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses()
testCases := []struct {
name string
modelInput string
expectModel string
}{
{
name: "With Profile Prefix",
modelInput: "test:TheExpectedModel",
expectModel: "TheExpectedModel", // Profile prefix should be stripped
},
{
name: "Without Profile Prefix",
modelInput: "TheExpectedModel",
expectModel: "TheExpectedModel", // Should remain the same
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a buffer with multipart form data // Create a buffer with multipart form data
var b bytes.Buffer var b bytes.Buffer
w := multipart.NewWriter(&b) w := multipart.NewWriter(&b)
@@ -416,7 +503,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
// Add the model field // Add the model field
fw, err := w.CreateFormField("model") fw, err := w.CreateFormField("model")
assert.NoError(t, err) assert.NoError(t, err)
_, err = fw.Write([]byte("TheExpectedModel")) _, err = fw.Write([]byte(tc.modelInput))
assert.NoError(t, err) assert.NoError(t, err)
// Add a file field // Add a file field
@@ -433,49 +520,94 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req) proxy.HandlerFunc(rec, req)
// Verify the response // Verify the response
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response) err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "TheExpectedModel", response["model"]) assert.Equal(t, tc.expectModel, response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"]) })
}
}
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
} }
// Test useModelName in configuration sends overrides what is sent to upstream // Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) { func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel" upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName) modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName modelConfig.UseModelName = upstreamModelName
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": modelConfig, "model1": modelConfig,
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses()
requestedModel := "model1" tests := []struct {
description string
requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) { for _, tt := range tests {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName) assert.Contains(t, w.Body.String(), upstreamModelName)
})
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) { })
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data // Create a buffer with multipart form data
var b bytes.Buffer var b bytes.Buffer
w := multipart.NewWriter(&b) w := multipart.NewWriter(&b)
@@ -483,7 +615,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
// Add the model field // Add the model field
fw, err := w.CreateFormField("model") fw, err := w.CreateFormField("model")
assert.NoError(t, err) assert.NoError(t, err)
_, err = fw.Write([]byte(requestedModel)) _, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err) assert.NoError(t, err)
// Add a file field // Add a file field
@@ -497,7 +629,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req) proxy.HandlerFunc(rec, req)
// Verify the response // Verify the response
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
@@ -506,114 +638,6 @@ func TestProxyManager_UseModelName(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"]) assert.Equal(t, upstreamModelName, response["model"])
}) })
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
} }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
func TestProxyManager_Upstream(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
}
func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
} }
-43
View File
@@ -1,43 +0,0 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
-77
View File
@@ -1,77 +0,0 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}