Compare commits

..

1 Commits

Author SHA1 Message Date
Benson Wong cebf9c4d34 increase health check to a minimum of 5 seconds 2025-02-18 17:35:52 -08:00
33 changed files with 550 additions and 2607 deletions
-23
View File
@@ -1,23 +0,0 @@
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
name: Close inactive issues
on:
schedule:
- cron: "32 1 * * *"
jobs:
close-issues:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v9
with:
days-before-issue-stale: 14
days-before-issue-close: 14
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}
+1 -3
View File
@@ -15,9 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
#platform: [intel, cuda, vulkan, cpu, musa] platform: [intel, cuda, vulkan, cpu, musa]
platform: [cuda, vulkan, cpu, musa]
fail-fast: false
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
-32
View File
@@ -1,32 +0,0 @@
# This workflow will build a golang project
name: CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
# necessary for testing proxy/Process swapping
- name: Create simple-responder
run: make simple-responder
- name: Test all
run: make test-all
-12
View File
@@ -16,15 +16,3 @@ builds:
goarch: arm64 goarch: arm64
- goos: windows - goos: windows
goarch: arm64 goarch: arm64
# use zip format for windows
archives:
- id: default
format: tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds_info:
group: root
owner: root
format_overrides:
- goos: windows
format: zip
+35 -107
View File
@@ -1,7 +1,4 @@
![llama-swap header image](header2.png) ![llama-swap header image](header.jpeg)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap # llama-swap
@@ -14,29 +11,26 @@ Written in golang, it is very easy to install (single binary with no dependancie
- ✅ Easy to deploy: single binary with no dependencies - ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file - ✅ Easy to config: single yaml file
- ✅ On-demand model switching - ✅ On-demand model switching
- ✅ Full control over server settings per model
- ✅ OpenAI API supported endpoints: - ✅ OpenAI API supported endpoints:
- `v1/completions` - `v1/completions`
- `v1/chat/completions` - `v1/chat/completions`
- `v1/embeddings` - `v1/embeddings`
- `v1/rerank` - `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867)) - ✅ Multiple GPU support
- ✅ llama-swap custom API endpoints
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support - ✅ Docker and Podman support
-Full control over server settings per model -Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Remote log monitoring at `/log`
- ✅ Automatic unloading of models from GPUs after timeout
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
## How does llama-swap work? ## How does llama-swap work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request. When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used. In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## config.yaml ## config.yaml
@@ -67,31 +61,15 @@ models:
# Default (and minimum) is 15 seconds # Default (and minimum) is 15 seconds
healthCheckTimeout: 60 healthCheckTimeout: 60
# Valid log levels: debug, info (default), warn, error # Write HTTP logs (useful for troubleshooting), defaults to false
logLevel: info logRequests: true
# Automatic Port Values
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
# when you use ${PORT} you can omit a custom model.proxy value, as it will
# default to http://localhost:${PORT}
# override the default port (5800) for automatic port values
startPort: 10001
# define valid model values and the upstream server start # define valid model values and the upstream server start
models: models:
"llama": "llama":
# multiline for readability cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# where to reach the server started by cmd, make sure the ports match # where to reach the server started by cmd, make sure the ports match
# can be omitted if you use an automatic ${PORT} in cmd
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
# aliases names to use this model for # aliases names to use this model for
@@ -110,89 +88,49 @@ models:
# default: 0 = never unload model # default: 0 = never unload model
ttl: 60 ttl: 60
# `useModelName` overrides the model name in the request "qwen":
# and sends a specific name to the upstream server # environment variables to pass to the command
useModelName: "qwen:qwq" env:
- "CUDA_VISIBLE_DEVICES=0"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
# unlisted models do not show up in /v1/models or /upstream lists # unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal # but they can still be requested as normal
"qwen-unlisted": "qwen-unlisted":
unlisted: true unlisted: true
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0 cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker Support (v26.1.4+ required!) # Docker Support (v26.1.4+ required!)
"docker-llama": "docker-llama":
proxy: "http://127.0.0.1:${PORT}" proxy: "http://127.0.0.1:9790"
cmd: > cmd: >
docker run --name dockertest docker run --name dockertest
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models --init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# Groups provide advanced controls over model swapping behaviour. Using groups # profiles make it easy to managing multi model (and gpu) configurations.
# some models can be kept loaded indefinitely, while others are swapped out.
# #
# Tips: # Tips:
# # - each model must be listening on a unique address and port
# - models must be defined above in the Models section # - the model name is in this format: "profile_name:model", like "coding:qwen"
# - a model can only be a member of one group # - the profile will load and unload all models in the profile at the same time
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields profiles:
# - see issue #109 for details coding:
# - "qwen"
# NOTE: the example below uses model names that are not defined above for demonstration purposes - "llama"
groups:
# group1 is the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap controls the model swapping behaviour in within the group
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# exclusive controls how the group affects other groups
# - true: causes all other groups to unload their models when this group runs a model
# - false: does not affect other groups
exclusive: true
# members references the models defined above
members:
- "llama"
- "qwen-unlisted"
# models in this group are never unloaded
"group2":
swap: false
exclusive: false
members:
- "docker-llama"
# (not defined above, here for example)
- "modelA"
- "modelB"
"forever":
# setting persistent to true causes the group to never be affected by the swapping behaviour of
# other groups. It is a shortcut to keeping some models always loaded.
persistent: true
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
``` ```
### Use Case Examples ### Advanced Examples
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints - [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases. - [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest. - [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
## Configuration
llama-s
</details> </details>
@@ -270,15 +208,9 @@ Of course, CLI access is also supported:
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' curl http://host/logs'
# streams combined logs # streams logs
curl -Ns 'http://host/logs/stream' curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes # stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time' curl -Ns http://host/logs/stream | grep 'eval time'
@@ -317,7 +249,3 @@ StartLimitInterval=30
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target
``` ```
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
+3 -3
View File
@@ -1,9 +1,9 @@
# Seconds to wait for llama.cpp to be available to serve requests # Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds # Default (and minimum): 15 seconds
healthCheckTimeout: 90 healthCheckTimeout: 15
# valid log levels: debug, info (default), warn, error # Log HTTP requests helpful for troubleshoot, defaults to False
logLevel: debug logRequests: true
models: models:
"llama": "llama":
-6
View File
@@ -38,12 +38,6 @@ else
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \ | jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}') | sort -r | head -n1 | awk -F '-' '{print $3}')
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}" CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}" CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER" echo "Building ${CONTAINER_TAG} $LS_VER"
-153
View File
@@ -1,153 +0,0 @@
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
## Here's what you you need:
- aider - [installation docs](https://aider.chat/docs/install.html)
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
- 24GB VRAM video card
## Running aider
The goal is getting this command line to work:
```sh
aider --architect \
--no-show-model-warnings \
--model openai/QwQ \
--editor-model openai/qwen-coder-32B \
--model-settings-file aider.model.settings.yml \
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1" \
```
Set `--openai-api-base` to the IP and port where your llama-swap is running.
## Create an aider model settings file
```yaml
# aider.model.settings.yml
#
# !!! important: model names must match llama-swap configuration names !!!
#
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
```
## llama-swap configuration
```yaml
# config.yaml
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
models:
"qwen-coder-32B":
proxy: "http://127.0.0.1:8999"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--cache-type-k q8_0 --cache-type-v q8_0
-ngl 99
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
"QwQ":
proxy: "http://127.0.0.1:9503"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
--min-p 0.01 --top-k 40 --top-p 0.95
-ngl 99
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
```
## Advanced, Dual GPU Configuration
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
In llama-swap's configuration file:
1. add a `profiles` section with `aider` as the profile name
2. using the `env` field to specify the GPU IDs for each model
```yaml
# config.yaml
# Add a profile for aider
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=0"
proxy: "http://127.0.0.1:8999"
cmd: /path/to/llama-server ...
"QwQ":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
cmd: /path/to/llama-server ...
```
Append the profile tag, `aider:`, to the model names in the model settings file
```yaml
# aider.model.settings.yml
- name: "openai/aider:QwQ"
weak_model_name: "openai/aider:qwen-coder-32B-aider"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
- name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
```
Run aider with:
```sh
$ aider --architect \
--no-show-model-warnings \
--model openai/aider:QwQ \
--editor-model openai/aider:qwen-coder-32B \
--config aider.conf.yml \
--model-settings-file aider.model.settings.yml
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1"
```
@@ -1,28 +0,0 @@
# this makes use of llama-swap's profile feature to
# keep the architect and editor models in VRAM on different GPUs
- name: "openai/aider:QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B"
- name: "openai/aider:qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/aider:qwen-coder-32B"
@@ -1,26 +0,0 @@
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
-49
View File
@@ -1,49 +0,0 @@
healthCheckTimeout: 300
logLevel: debug
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
env:
- "CUDA_VISIBLE_DEVICES=0"
aliases:
- coder
proxy: "http://127.0.0.1:8999"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--ctx-size-draft 16000
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
-ngl 99 -ngld 99
--draft-max 16 --draft-min 4 --draft-p-min 0.4
--cache-type-k q8_0 --cache-type-v q8_0
"QwQ":
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6
--repeat-penalty 1.1
--dry-multiplier 0.5
--min-p 0.01
--top-k 40
--top-p 0.95
-ngl 99 -ngld 99
@@ -1,51 +0,0 @@
# Restart llama-swap on config change
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
```bash
#!/bin/bash
#
# A simple watch and restart llama-swap when its configuration
# file changes. Useful for trying out configuration changes
# without manually restarting the server each time.
if [ -z "$1" ]; then
echo "Usage: $0 <path to config.yaml>"
exit 1
fi
while true; do
# Start the process again
./llama-swap-linux-amd64 -config $1 -listen :1867 &
PID=$!
echo "Started llama-swap with PID $PID"
# Wait for modifications in the specified directory or file
inotifywait -e modify "$1"
# Check if process exists before sending signal
if kill -0 $PID 2>/dev/null; then
echo "Sending SIGTERM to $PID"
kill -SIGTERM $PID
wait $PID
else
echo "Process $PID no longer exists"
fi
sleep 1
done
```
## Usage and output example
```bash
$ ./watch-and-restart.sh config.yaml
Started llama-swap with PID 495455
Setting up watches.
Watches established.
llama-swap listening on :1867
Sending SIGTERM to 495455
Shutting down llama-swap
Started llama-swap with PID 495486
Setting up watches.
Watches established.
llama-swap listening on :1867
```
+6 -10
View File
@@ -3,11 +3,7 @@ module github.com/mostlygeek/llama-swap
go 1.23.0 go 1.23.0
require ( require (
github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
@@ -19,10 +15,12 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.10.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
@@ -31,14 +29,12 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.38.0 // indirect golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
) )
-20
View File
@@ -57,16 +57,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -78,30 +68,20 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 351 KiB

-4
View File
@@ -34,10 +34,6 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if len(config.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}
if mode := os.Getenv("GIN_MODE"); mode != "" { if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode) gin.SetMode(mode)
} else { } else {
+4 -75
View File
@@ -12,14 +12,12 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
) )
func main() { func main() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
// Define a command-line flag for the port // Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on") port := flag.String("port", "8080", "port to listen on")
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
// Define a command-line flag for the response message // Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with") responseMessage := flag.String("respond", "hi", "message to respond with")
@@ -33,88 +31,19 @@ func main() {
// Set up the handler function using the provided response message // Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) { r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "text/plain")
// add a wait to simulate a slow query // add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil { if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait) time.Sleep(wait)
} }
c.JSON(http.StatusOK, gin.H{ c.String(200, *responseMessage)
"responseMessage": *responseMessage,
"h_content_length": c.Request.Header.Get("Content-Length"),
})
})
// for issue #62 to check model name strips profile slug
// has to be one of the openAI API endpoints that llama-swap proxies
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
r.POST("/v1/audio/speech", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
if modelName != *expectedModel {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
return
} else {
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
}) })
r.POST("/v1/completions", func(c *gin.Context) { r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "text/plain")
c.JSON(http.StatusOK, gin.H{ c.String(200, *responseMessage)
"responseMessage": *responseMessage,
})
})
// issue #41
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
// Parse the multipart form
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
// Get the model from the form values
model := c.Request.FormValue("model")
if model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
return
}
// Get the file from the form
file, _, err := c.Request.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
defer file.Close()
// Read the file content to get its size
fileBytes, err := io.ReadAll(file)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
return
}
fileSize := len(fileBytes)
// Return a JSON response with the model and transcription text including file size
c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model,
// expose some header values for testing
"h_content_type": c.GetHeader("Content-Type"),
"h_content_length": c.GetHeader("Content-Length"),
})
}) })
r.GET("/slow-respond", func(c *gin.Context) { r.GET("/slow-respond", func(c *gin.Context) {
+6 -152
View File
@@ -2,18 +2,13 @@ package proxy
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"sort"
"strconv"
"strings" "strings"
"github.com/google/shlex" "github.com/google/shlex"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct { type ModelConfig struct {
Cmd string `yaml:"cmd"` Cmd string `yaml:"cmd"`
Proxy string `yaml:"proxy"` Proxy string `yaml:"proxy"`
@@ -22,51 +17,20 @@ type ModelConfig struct {
CheckEndpoint string `yaml:"checkEndpoint"` CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"` UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"` Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
} }
func (m *ModelConfig) SanitizedCommand() ([]string, error) { func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd) return SanitizeCommand(m.Cmd)
} }
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type Config struct { type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"` LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"` Models map[string]ModelConfig `yaml:"models"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// map aliases to actual model IDs // map aliases to actual model IDs
aliases map[string]string aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
} }
func (c *Config) RealModelName(search string) (string, bool) { func (c *Config) RealModelName(search string) (string, bool) {
@@ -87,141 +51,31 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
} }
} }
func LoadConfig(path string) (Config, error) { func LoadConfig(path string) (*Config, error) {
file, err := os.Open(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return Config{}, err return nil, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
} }
var config Config var config Config
err = yaml.Unmarshal(data, &config) err = yaml.Unmarshal(data, &config)
if err != nil { if err != nil {
return Config{}, err return nil, err
} }
if config.HealthCheckTimeout < 15 { if config.HealthCheckTimeout < 15 {
config.HealthCheckTimeout = 15 config.HealthCheckTimeout = 15
} }
// set default port ranges
if config.StartPort == 0 {
// default to 5800
config.StartPort = 5800
} else if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
// Populate the aliases map // Populate the aliases map
config.aliases = make(map[string]string) config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models { for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases { for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName config.aliases[alias] = modelName
} }
} }
// iterate over the models and replace any ${PORT} with the next available port return &config, nil
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
// iterate over the sorted models
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
if strings.Contains(modelConfig.Cmd, "${PORT}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
if modelConfig.Proxy == "" {
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
} else {
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
}
nextPort++
config.Models[modelId] = modelConfig
} else if modelConfig.Proxy == "" {
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
}
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
} }
func SanitizeCommand(cmdStr string) ([]string, error) { func SanitizeCommand(cmdStr string) ([]string, error) {
+1 -186
View File
@@ -3,7 +3,6 @@ package proxy
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -36,32 +35,11 @@ models:
aliases: aliases:
- "m2" - "m2"
checkEndpoint: "/" checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "mthree"
checkEndpoint: "/"
model4:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8082"
checkEndpoint: "/"
healthCheckTimeout: 15 healthCheckTimeout: 15
profiles: profiles:
test: test:
- model1 - model1
- model2 - model2
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
forever:
exclusive: false
persistent: true
members:
- "model4"
` `
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
@@ -74,8 +52,7 @@ groups:
t.Fatalf("Failed to load config: %v", err) t.Fatalf("Failed to load config: %v", err)
} }
expected := Config{ expected := &Config{
StartPort: 5800,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -91,18 +68,6 @@ groups:
Env: nil, Env: nil,
CheckEndpoint: "/", CheckEndpoint: "/",
}, },
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: nil,
CheckEndpoint: "/",
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
},
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{ Profiles: map[string][]string{
@@ -112,25 +77,6 @@ groups:
"m1": "model1", "m1": "model1",
"model-one": "model1", "model-one": "model1",
"m2": "model2", "m2": "model2",
"mthree": "model3",
},
Groups: map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
}, },
} }
@@ -141,63 +87,6 @@ groups:
assert.Equal(t, "model1", realname) assert.Equal(t, "model1", realname)
} }
func TestConfig_GroupMemberIsUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
healthCheckTimeout: 15
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
group2:
swap: true
exclusive: false
members: ["model2"]
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// a Contains as order of the map is not guaranteed
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
}
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
aliases:
- m1
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
aliases:
- m1
- m2
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// this is a contains because it could be `model1` or `model2` depending on the order
// go decided on the order of the map
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{ config := &ModelConfig{
Cmd: `python model1.py \ Cmd: `python model1.py \
@@ -285,77 +174,3 @@ func TestConfig_SanitizeCommand(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, args) assert.Nil(t, args)
} }
func TestConfig_AutomaticPortAssignments(t *testing.T) {
t.Run("Default Port Ranges", func(t *testing.T) {
content := ``
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
})
t.Run("User specific port ranges", func(t *testing.T) {
content := `startPort: 1000`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 1000, config.StartPort)
})
t.Run("Invalid start port", func(t *testing.T) {
content := `startPort: abcd`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("start port must be greater than 1", func(t *testing.T) {
content := `startPort: -99`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("Automatic port assignments", func(t *testing.T) {
content := `
startPort: 5800
models:
model1:
cmd: svr --port ${PORT}
model2:
cmd: svr --port ${PORT}
proxy: "http://172.11.22.33:${PORT}"
model3:
cmd: svr --port 1999
proxy: "http://1.2.3.4:1999"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
})
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
content := `
models:
model1:
cmd: svr --port 111
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
})
}
-12
View File
@@ -14,7 +14,6 @@ import (
var ( var (
nextTestPort int = 12000 nextTestPort int = 12000
portMutex sync.Mutex portMutex sync.Mutex
testLogger = NewLogMonitorWriter(os.Stdout)
) )
// Check if the binary exists // Check if the binary exists
@@ -27,17 +26,6 @@ func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
switch os.Getenv("LOG_LEVEL") {
case "debug":
testLogger.SetLogLevel(LevelDebug)
case "warn":
testLogger.SetLogLevel(LevelWarn)
case "info":
testLogger.SetLogLevel(LevelInfo)
default:
testLogger.SetLogLevel(LevelWarn)
}
m.Run() m.Run()
} }
+78 -192
View File
@@ -12,65 +12,32 @@
flex-direction: column; flex-direction: column;
font-family: "Courier New", Courier, monospace; font-family: "Courier New", Courier, monospace;
} }
.log-container { #log-controls {
display: flex;
flex: 1;
gap: 0.5em;
margin: 0.5em; margin: 0.5em;
min-height: 0;
}
.log-column {
display: flex; display: flex;
flex-direction: column; align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1; flex: 1;
min-width: 0;
transition: flex 0.3s ease;
} }
.log-column.minimized { #log-controls input:focus {
flex: 0.1; outline: none; /* Ensures no outline is shown when the input is focused */
max-width: 50px;
border: 1px solid #777;
color: green;
} }
.log-controls { #log-stream {
display: grid;
grid-template-columns: 1fr auto;
gap: 0.5em;
margin-bottom: 0.5em;
}
.log-controls input {
width: 100%;
padding: 4px;
}
.log-controls input:focus {
outline: none;
}
.log-stream {
flex: 1; flex: 1;
margin: 0.5em;
padding: 1em; padding: 1em;
background: #f4f4f4; background: #f4f4f4;
overflow-y: auto; overflow-y: auto;
white-space: pre-wrap; white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; word-wrap: break-word; /* Ensures long words wrap */
min-height: 0;
} }
.regex-error { .regex-error {
background-color: #ff0000 !important; background-color: #ff0000 !important;
} }
/* Make headers clickable and show pointer cursor */
h2 {
cursor: pointer;
user-select: none;
margin: 0 0 0.5em 0;
padding: 0.5em;
}
h2:hover {
background-color: rgba(0, 0, 0, 0.05);
}
/* Dark mode styles */ /* Dark mode styles */
@media (prefers-color-scheme: dark) { @media (prefers-color-scheme: dark) {
body { body {
@@ -78,182 +45,101 @@
color: #fff; color: #fff;
} }
.log-stream { #log-stream {
background: #444; background: #444;
color: #fff; color: #fff;
} }
.log-controls input { #log-controls input {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
.log-controls button { #log-controls button {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
h2:hover {
background-color: rgba(255, 255, 255, 0.1);
}
}
/* Hide content when minimized */
.log-column.minimized .log-controls,
.log-column.minimized .log-stream {
display: none;
}
.log-column.minimized h2 {
writing-mode: vertical-rl;
text-orientation: mixed;
transform: rotate(180deg);
white-space: nowrap;
margin: auto;
} }
</style> </style>
</head> </head>
<body> <body>
<div class="log-container"> <pre id="log-stream">Waiting for logs...</pre>
<div class="log-column"> <div id="log-controls">
<h2>Proxy Logs</h2> <input type="text" id="filter-input" placeholder="regex filter">
<div class="log-controls"> <button id="clear-button">clear</button>
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
<button id="proxy-clear-button">clear</button>
</div>
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
</div>
<div class="log-column minimized">
<h2>Upstream Logs</h2>
<div class="log-controls">
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
<button id="upstream-clear-button">clear</button>
</div>
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
</div>
</div> </div>
<script> <script>
class LogStream { const logStream = document.getElementById('log-stream');
constructor(streamElement, filterInput, clearButton, endpoint) { const filterInput = document.getElementById('filter-input');
this.streamElement = streamElement; var logData = "";
this.filterInput = filterInput; let regexFilter = null;
this.clearButton = clearButton;
this.endpoint = endpoint;
this.logData = "";
this.regexFilter = null;
this.eventSource = null;
this.initialize(); function setupEventSource() {
} if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
initialize() { eventSource.onmessage = function(event) {
this.filterInput.addEventListener('input', () => this.updateFilter()); logData += event.data;
this.clearButton.addEventListener('click', () => { render()
this.filterInput.value = "";
this.regexFilter = null;
this.render();
});
this.setupEventSource();
}
setupEventSource() {
if (typeof(EventSource) === "undefined") {
this.logData = "SSE Not supported by this browser.";
this.render();
return;
}
const connect = () => {
this.eventSource = new EventSource(this.endpoint);
this.eventSource.onmessage = (event) => {
this.logData += event.data;
this.logData = this.logData.slice(-1024 * 100);
this.render();
};
this.eventSource.onerror = (err) => {
// Close the current connection
this.eventSource.close();
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
this.render();
// Attempt to reconnect after 5 seconds
setTimeout(() => {
this.logData += "Attempting to reconnect...\n";
this.render();
connect();
}, 5000);
};
}; };
// Initial connection eventSource.onerror = function(err) {
connect(); logData = "EventSource failed: " + err.message;
} };
} else {
render() { logData = "SSE Not supported by this browser."
let content = this.logData;
if (this.regexFilter) {
const lines = content.split('\n');
const filteredLines = lines.filter(line => this.regexFilter.test(line));
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
}
this.streamElement.textContent = content;
this.streamElement.scrollTop = this.streamElement.scrollHeight;
}
updateFilter() {
const pattern = this.filterInput.value.trim();
this.filterInput.classList.remove('regex-error');
if (!pattern) {
this.regexFilter = null;
this.render();
return;
}
try {
this.regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
this.regexFilter = null;
this.filterInput.classList.add('regex-error');
return;
}
this.render();
} }
} }
// Initialize both log streams // poor-ai's react ¯\_(ツ)_/¯
document.addEventListener('DOMContentLoaded', () => { function render() {
new LogStream( if (regexFilter) {
document.getElementById('proxy-log-stream'), const lines = logData.split('\n');
document.getElementById('proxy-filter-input'), const filteredLines = lines.filter(line => {
document.getElementById('proxy-clear-button'), return regexFilter === null || regexFilter.test(line);
"/logs/streamSSE/proxy"
);
new LogStream(
document.getElementById('upstream-log-stream'),
document.getElementById('upstream-filter-input'),
document.getElementById('upstream-clear-button'),
"/logs/streamSSE/upstream"
);
// Initialize clickable headers
document.querySelectorAll('h2').forEach(header => {
header.addEventListener('click', () => {
const column = header.closest('.log-column');
column.classList.toggle('minimized');
}); });
});
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
}
logStream.scrollTop = logStream.scrollHeight;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try {
regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
regexFilter = null;
filterInput.classList.add('regex-error');
return
}
} else {
regexFilter = null;
}
render();
}
filterInput.addEventListener('input', updateFilter);
document.getElementById('clear-button').addEventListener('click', () => {
filterInput.value = "";
regexFilter = null;
render();
}); });
setupEventSource();
updateFilter();
</script> </script>
</body> </body>
</html> </html>
-90
View File
@@ -2,21 +2,11 @@ package proxy
import ( import (
"container/ring" "container/ring"
"fmt"
"io" "io"
"os" "os"
"sync" "sync"
) )
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct { type LogMonitor struct {
clients map[chan []byte]bool clients map[chan []byte]bool
mu sync.RWMutex mu sync.RWMutex
@@ -25,10 +15,6 @@ type LogMonitor struct {
// typically this can be os.Stdout // typically this can be os.Stdout
stdout io.Writer stdout io.Writer
// logging levels
level LogLevel
prefix string
} }
func NewLogMonitor() *LogMonitor { func NewLogMonitor() *LogMonitor {
@@ -40,8 +26,6 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool), clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout, stdout: stdout,
level: LevelInfo,
prefix: "",
} }
} }
@@ -110,77 +94,3 @@ func (w *LogMonitor) broadcast(msg []byte) {
} }
} }
} }
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+112 -179
View File
@@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@@ -31,18 +30,11 @@ const (
) )
type Process struct { type Process struct {
ID string ID string
config ModelConfig config ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
logMonitor *LogMonitor
// for p.cmd.Wait() select { ... } healthCheckTimeout int
cmdWaitChan chan error
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time lastRequestHandled time.Time
@@ -59,70 +51,54 @@ type Process struct {
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
cmdWaitChan: make(chan error, 1), logMonitor: logMonitor,
processLogger: processLogger, healthCheckTimeout: healthCheckTimeout,
proxyLogger: proxyLogger, state: StateStopped,
healthCheckTimeout: healthCheckTimeout, shutdownCtx: ctx,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */ shutdownCancel: cancel,
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
} }
} }
// LogMonitor returns the log monitor associated with the process. func (p *Process) setState(newState ProcessState) error {
func (p *Process) LogMonitor() *LogMonitor { // enforce valid state transitions
return p.processLogger invalidTransition := false
} if p.state == StateStopped {
// stopped -> starting
// custom error types for swapping state if newState != StateStarting {
var ( invalidTransition = true
ErrExpectedStateMismatch = errors.New("expected state mismatch") }
ErrInvalidStateTransition = errors.New("invalid state transition") } else if p.state == StateStarting {
) // starting -> ready | failed | stopping
if newState != StateReady && newState != StateFailed && newState != StateStopping {
// swapState performs a compare and swap of the state atomically. It returns the current state invalidTransition = true
// and an error if the swap failed. }
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) { } else if p.state == StateReady {
p.stateMutex.Lock() // ready -> stopping
defer p.stateMutex.Unlock() if newState != StateStopping {
invalidTransition = true
if p.state != expectedState { }
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState) } else if p.state == StateStopping {
return p.state, ErrExpectedStateMismatch // stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
} }
if !isValidTransition(p.state, newState) { if invalidTransition {
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState) //panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
return p.state, ErrInvalidStateTransition return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
} }
p.state = newState p.state = newState
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState) return nil
return p.state, nil
}
// Helper function to encapsulate transition rules
func isValidTransition(from, to ProcessState) bool {
switch from {
case StateStopped:
return to == StateStarting
case StateStarting:
return to == StateReady || to == StateFailed || to == StateStopping
case StateReady:
return to == StateStopping
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateFailed, StateShutdown:
return false // No transitions allowed from these states
}
return false
} }
func (p *Process) CurrentState() ProcessState { func (p *Process) CurrentState() ProcessState {
@@ -140,58 +116,50 @@ func (p *Process) start() error {
return fmt.Errorf("can not start(), upstream proxy missing") return fmt.Errorf("can not start(), upstream proxy missing")
} }
args, err := p.config.SanitizedCommand() // wait for the other start() to complete
if err != nil { curState := p.CurrentState()
return fmt.Errorf("unable to get sanitized command: %v", err)
if curState == StateReady {
return nil
} }
if curState, err := p.swapState(StateStopped, StateStarting); err != nil { if curState == StateStarting {
if err == ErrExpectedStateMismatch { p.waitStarting.Wait()
// already starting, just wait for it to complete and expect
// it to be be in the Ready start after. If not, return an error if state := p.CurrentState(); state != StateReady {
if curState == StateStarting { return fmt.Errorf("start() failed current state: %v", state)
p.waitStarting.Wait()
if state := p.CurrentState(); state == StateReady {
return nil
} else {
return fmt.Errorf("process was already starting but wound up in state %v", state)
}
} else {
return fmt.Errorf("processes was in state %v when start() was called", curState)
}
} else {
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
} }
return nil
}
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if err := p.setState(StateStarting); err != nil {
return err
} }
p.waitStarting.Add(1) p.waitStarting.Add(1)
defer p.waitStarting.Done() defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand()
if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err)
}
p.cmd = exec.Command(args[0], args[1:]...) p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.processLogger p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.processLogger p.cmd.Stderr = p.logMonitor
p.cmd.Env = p.config.Env p.cmd.Env = p.config.Env
err = p.cmd.Start() err = p.cmd.Start()
// Set process state to failed
if err != nil { if err != nil {
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil { p.setState(StateFailed)
return fmt.Errorf(
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
err, curState, swapErr,
)
}
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
// Capture the exit error for later signaling
go func() {
exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
p.cmdWaitChan <- exitErr
}()
// One of three things can happen at this stage: // One of three things can happen at this stage:
// 1. The command exits unexpectedly // 1. The command exits unexpectedly
// 2. The health check fails // 2. The health check fails
@@ -223,51 +191,31 @@ func (p *Process) start() error {
) )
defer cancelHealthCheck() defer cancelHealthCheck()
// Health check loop
loop: loop:
// Ready Check loop
for { for {
select { select {
case <-checkDeadline.Done(): case <-checkDeadline.Done():
if curState, err := p.swapState(StateStarting, StateFailed); err != nil { p.setState(StateFailed)
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState) return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
} else {
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
case <-p.shutdownCtx.Done(): case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown") return errors.New("health check interrupted due to shutdown")
case exitErr := <-p.cmdWaitChan:
if exitErr != nil {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
} else {
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
}
} else {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
} else {
return fmt.Errorf("upstream command exited prematurely but successfully")
}
}
default: default:
if err := p.checkHealthEndpoint(healthURL); err == nil { if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
cancelHealthCheck() cancelHealthCheck()
break loop break loop
} else { } else {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline() endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime) ttl := time.Until(endTime)
p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds()) fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
} else { } else {
p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err) fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
} }
} }
} }
<-time.After(p.healthCheckLoopInterval) <-time.After(5 * time.Second)
} }
} }
@@ -278,7 +226,7 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) { for range time.Tick(time.Second) {
if p.CurrentState() != StateReady { if p.state != StateReady {
return return
} }
@@ -286,7 +234,7 @@ func (p *Process) start() error {
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter) fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return return
} }
@@ -294,33 +242,26 @@ func (p *Process) start() error {
}() }()
} }
if curState, err := p.swapState(StateStarting, StateReady); err != nil { return p.setState(StateReady)
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
} else {
return nil
}
} }
func (p *Process) Stop() { func (p *Process) Stop() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.proxyLogger.Debugf("<%s> Stopping process", p.ID) p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// calling Stop() when state is invalid is a no-op // calling Stop() when state is invalid is a no-op
if curState, err := p.swapState(StateReady, StateStopping); err != nil { if err := p.setState(StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
return return
} }
// stop the process with a graceful exit timeout // stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil { if err := p.setState(StateStopped); err != nil {
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState) panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
} }
} }
@@ -328,59 +269,65 @@ func (p *Process) Stop() {
// of time for any inflight requests to complete before shutting down. If the Process // of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down // is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() { func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel() p.shutdownCancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
p.state = StateShutdown if err := p.setState(StateShutdown); err != nil {
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
}
p.setState(StateShutdown)
} }
// stopCommand will send a SIGTERM to the process and wait for it to exit. // stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL. // If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) { func (p *Process) stopCommand(sigtermTTL time.Duration) {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL) sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout() defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
return return
} }
if err := p.terminateProcess(); err != nil { p.cmd.Process.Signal(syscall.SIGTERM)
p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err)
}
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.cmd.Process.Kill() p.cmd.Process.Kill()
case err := <-p.cmdWaitChan: case err := <-sigtermNormal:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that ithe cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now.
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok { if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno) fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok { } else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") { if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Infof("<%s> Process stopped OK", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") { } else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
} else { } else {
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
} }
} else { } else {
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err) fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
} }
} }
} }
} }
func (p *Process) checkHealthEndpoint(healthURL string) error { func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{ client := &http.Client{
Timeout: 500 * time.Millisecond, Timeout: 500 * time.Millisecond,
} }
@@ -405,8 +352,6 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
} }
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
requestBeginTime := time.Now()
var startDuration time.Duration
// prevent new requests from being made while stopping or irrecoverable // prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState() currentState := p.CurrentState()
@@ -423,13 +368,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// start the process on demand // start the process on demand
if p.CurrentState() != StateReady { if p.CurrentState() != StateReady {
beginStartTime := time.Now()
if err := p.start(); err != nil { if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err) errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway) http.Error(w, errstr, http.StatusBadGateway)
return return
} }
startDuration = time.Since(beginStartTime)
} }
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
@@ -440,12 +383,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
req.Header = r.Header.Clone() req.Header = r.Header.Clone()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway) http.Error(w, err.Error(), http.StatusBadGateway)
@@ -479,8 +416,4 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
} }
-9
View File
@@ -1,9 +0,0 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
-14
View File
@@ -1,14 +0,0 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+39 -76
View File
@@ -2,6 +2,7 @@ package proxy
import ( import (
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -12,26 +13,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var (
debugLogger = NewLogMonitorWriter(os.Stdout)
)
func init() {
// flip to help with debugging tests
if false {
debugLogger.SetLogLevel(LevelDebug)
} else {
debugLogger.SetLogLevel(LevelError)
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
// Create a process // Create a process
process := NewProcess("test-process", 5, config, debugLogger, debugLogger) process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop() defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
@@ -64,10 +52,11 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start() // are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) { func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger) process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop() defer process.Stop()
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -95,7 +84,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }
process := NewProcess("broken", 1, config, debugLogger, debugLogger) process := NewProcess("broken", 1, config, NewLogMonitor())
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -120,7 +109,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop() defer process.Stop()
// this should take 4 seconds // this should take 4 seconds
@@ -162,7 +151,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter) assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, debugLogger, debugLogger) process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop() defer process.Stop()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@@ -180,8 +169,6 @@ func TestProcess_LowTTLValue(t *testing.T) {
} }
// issue #19 // issue #19
// This test makes sure using Process.Stop() does not affect pending HTTP
// requests. All HTTP requests in this test should complete successfully.
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skipping slow test") t.Skip("skipping slow test")
@@ -189,7 +176,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345" expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, debugLogger, debugLogger) process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop() defer process.Stop()
results := map[string]string{ results := map[string]string{
@@ -205,9 +192,8 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(key string) { go func(key string) {
defer wg.Done() defer wg.Done()
// send a request where simple-responder is will wait 300ms before responding // send a request that should take 5 * 200ms (1 second) to complete
// this will simulate an in-progress request. req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
process.ProxyRequest(w, req) process.ProxyRequest(w, req)
@@ -223,9 +209,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
}(key) }(key)
} }
// Stop the process while requests are still being processed // stop the requests in the middle
go func() { go func() {
<-time.After(150 * time.Millisecond) <-time.After(500 * time.Millisecond)
process.Stop() process.Stop()
}() }()
@@ -236,40 +222,39 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
} }
} }
func TestProcess_SwapState(t *testing.T) { func TestSetState(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
currentState ProcessState currentState ProcessState
expectedState ProcessState
newState ProcessState newState ProcessState
expectedError error expectedError error
expectedResult ProcessState expectedResult ProcessState
}{ }{
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting}, {"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady}, {"Starting to Ready", StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed}, {"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping}, {"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping}, {"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped}, {"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown}, {"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped}, {"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting}, {"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady}, {"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady}, {"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping}, {"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed}, {"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed}, {"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown}, {"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown}, {"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger) p := &Process{
p.state = test.currentState state: test.currentState,
}
resultState, err := p.swapState(test.expectedState, test.newState) err := p.setState(test.newState)
if err != nil && test.expectedError == nil { if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil { } else if err == nil && test.expectedError != nil {
@@ -280,8 +265,8 @@ func TestProcess_SwapState(t *testing.T) {
} }
} }
if resultState != test.expectedResult { if p.state != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState) t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
} }
}) })
} }
@@ -292,6 +277,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test") t.Skip("skipping long shutdown test")
} }
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong // make a config where the healthcheck will always fail because port is wrong
@@ -299,16 +285,13 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test" config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30 healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger) process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
// start a goroutine to simulate a shutdown // start a goroutine to simulate a shutdown
var wg sync.WaitGroup var wg sync.WaitGroup
go func() { go func() {
defer wg.Done() defer wg.Done()
<-time.After(time.Millisecond * 500) <-time.After(time.Second * 2)
process.Shutdown() process.Shutdown()
}() }()
wg.Add(1) wg.Add(1)
@@ -320,23 +303,3 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
assert.ErrorContains(t, err, "health check interrupted due to shutdown") assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState()) assert.Equal(t, StateShutdown, process.CurrentState())
} }
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping Exit Interrupts Health Check test")
}
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = time.Second // make it faster
err := process.start()
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed)
}
-113
View File
@@ -1,113 +0,0 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sync"
)
type ProcessGroup struct {
sync.Mutex
config Config
id string
swap bool
exclusive bool
persistent bool
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// map of current processes
processes map[string]*Process
lastUsedProcess string
}
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
}
pg := &ProcessGroup{
id: id,
config: config,
swap: groupConfig.Swap,
exclusive: groupConfig.Exclusive,
persistent: groupConfig.Persistent,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
processes: make(map[string]*Process),
}
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
return pg
}
// ProxyRequest proxies a request to the specified model
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
if !pg.HasMember(modelID) {
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
}
if pg.swap {
pg.Lock()
if pg.lastUsedProcess != modelID {
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
}
pg.lastUsedProcess = modelID
}
pg.Unlock()
}
pg.processes[modelID].ProxyRequest(writer, request)
return nil
}
func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) StopProcesses() {
pg.Lock()
defer pg.Unlock()
pg.stopProcesses()
}
// stopProcesses stops all processes in the group
func (pg *ProcessGroup) stopProcesses() {
if len(pg.processes) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait()
}
func (pg *ProcessGroup) Shutdown() {
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
-96
View File
@@ -1,96 +0,0 @@
package proxy
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model2"},
},
"G2": {
Swap: false,
Exclusive: true,
Members: []string{"model3", "model4"},
},
},
})
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}
func TestProcessGroup_HasMember(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model1"))
assert.True(t, pg.HasMember("model2"))
assert.False(t, pg.HasMember("model3"))
}
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses()
tests := []string{"model1", "model2"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
// make sure only one process is in the running state
count := 0
for _, process := range pg.processes {
if process.CurrentState() == StateReady {
count++
}
}
assert.Equal(t, 1, count)
})
}
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses()
tests := []string{"model3", "model4"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
})
}
// make sure all the processes are running
for _, process := range pg.processes {
assert.Equal(t, StateReady, process.CurrentState())
}
}
+167 -323
View File
@@ -5,9 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -15,8 +13,6 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
const ( const (
@@ -26,111 +22,61 @@ const (
type ProxyManager struct { type ProxyManager struct {
sync.Mutex sync.Mutex
config Config config *Config
ginEngine *gin.Engine currentProcesses map[string]*Process
logMonitor *LogMonitor
// logging ginEngine *gin.Engine
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
processGroups map[string]*ProcessGroup
} }
func New(config Config) *ProxyManager { func New(config *Config) *ProxyManager {
// set up loggers pm := &ProxyManager{
stdoutLogger := NewLogMonitorWriter(os.Stdout) config: config,
upstreamLogger := NewLogMonitorWriter(stdoutLogger) currentProcesses: make(map[string]*Process),
proxyLogger := NewLogMonitorWriter(stdoutLogger) logMonitor: NewLogMonitor(),
ginEngine: gin.New(),
}
if config.LogRequests { if config.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.") pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
} }
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) { // see: https://github.com/mostlygeek/llama-swap/issues/42
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
upstreamLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
upstreamLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{
config: config,
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
processGroups: make(map[string]*ProcessGroup),
}
// create the process groups
for groupID := range config.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP,
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
// see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint // respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" { if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
// allow whatever the client requested by default c.AbortWithStatus(204)
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return return
} }
c.Next() c.Next()
@@ -147,7 +93,6 @@ func New(config Config) *ProxyManager {
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -155,16 +100,10 @@ func New(config Config) *ProxyManager {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex) pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream) pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/", func(c *gin.Context) { pm.ginEngine.GET("/", func(c *gin.Context) {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
c.Header("Content-Type", "text/html") c.Header("Content-Type", "text/html")
@@ -208,17 +147,27 @@ func (pm *ProxyManager) StopProcesses() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
// stop Processes in parallel pm.stopProcesses()
var wg sync.WaitGroup }
for _, processGroup := range pm.processGroups {
wg.Add(1) // for internal usage
go func(processGroup *ProcessGroup) { func (pm *ProxyManager) stopProcesses() {
defer wg.Done() if len(pm.currentProcesses) == 0 {
processGroup.stopProcesses() return
}(processGroup)
} }
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait() wg.Wait()
pm.currentProcesses = make(map[string]*Process)
} }
// Shutdown is called to shutdown all upstream processes // Shutdown is called to shutdown all upstream processes
@@ -227,44 +176,18 @@ func (pm *ProxyManager) Shutdown() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
pm.proxyLogger.Debug("Shutdown() called in proxy manager") // shutdown process in parallel
var wg sync.WaitGroup var wg sync.WaitGroup
// Send shutdown signal to all process in groups for _, process := range pm.currentProcesses {
for _, processGroup := range pm.processGroups {
wg.Add(1) wg.Add(1)
go func(processGroup *ProcessGroup) { go func(process *Process) {
defer wg.Done() defer wg.Done()
processGroup.Shutdown() process.Shutdown()
}(processGroup) }(process)
} }
wg.Wait() wg.Wait()
} }
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
}
if processGroup.exclusive {
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses()
}
}
}
return processGroup, realModelName, nil
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) { func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{} data := []interface{}{}
for id, modelConfig := range pm.config.Models { for id, modelConfig := range pm.config.Models {
@@ -294,6 +217,82 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
} }
} }
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock()
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// check if model is part of the profile
if profileName != "" {
found := false
for _, item := range pm.config.Profiles[profileName] {
if item == realModelName {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
}
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
// stop all running models
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
requestedModel := c.Param("model_id") requestedModel := c.Param("model_id")
@@ -302,15 +301,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return return
} }
processGroup, _, err := pm.swapProcessGroup(requestedModel) if process, err := pm.swapModel(requestedModel); err != nil {
if err != nil { pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) } else {
return // rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
process.ProxyRequest(c.Writer, c.Request)
} }
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
} }
func (pm *ProxyManager) upstreamIndex(c *gin.Context) { func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
@@ -345,148 +342,28 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return return
} }
requestedModel := gjson.GetBytes(bodyBytes, "model").String() var requestBody map[string]interface{}
if requestedModel == "" { if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
return
}
model, ok := requestBody["model"].(string)
if !ok {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return return
} }
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) if process, err := pm.swapModel(model); err != nil {
if err != nil { pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
} } else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// issue #69 allow custom model names to be sent to upstream // dechunk it as we already have all the body bytes see issue #11
useModelName := pm.config.Models[realModelName].UseModelName c.Request.Header.Del("transfer-encoding")
if useModelName != "" { c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) process.ProxyRequest(c.Writer, c.Request)
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open()
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
return
}
if _, err = io.Copy(formFile, file); err != nil {
file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return
}
file.Close()
}
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
} }
} }
@@ -500,39 +377,6 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
} }
} }
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { func ProcessKeyName(groupName, modelName string) string {
pm.StopProcesses() return groupName + PROFILE_SPLIT_CHAR + modelName
c.String(http.StatusOK, "OK")
}
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
})
}
}
}
// Put the results under the `running` key.
response := gin.H{
"running": runningProcesses,
}
context.JSON(http.StatusOK, response) // Always return 200 OK
}
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
for _, group := range pm.processGroups {
if group.HasMember(modelName) {
return group
}
}
return nil
} }
+8 -37
View File
@@ -9,6 +9,7 @@ import (
) )
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept") accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") { if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
@@ -27,7 +28,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
} }
} else { } else {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "text/plain")
history := pm.muxLogger.GetHistory() history := pm.logMonitor.GetHistory()
_, err := c.Writer.Write(history) _, err := c.Writer.Write(history)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
@@ -41,14 +42,8 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked") c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID") ch := pm.logMonitor.Subscribe()
logger, err := pm.getLogger(logMonitorId) defer pm.logMonitor.Unsubscribe(ch)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
@@ -61,7 +56,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped // Send history first if not skipped
if !skipHistory { if !skipHistory {
history := logger.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.Writer.Write(history) c.Writer.Write(history)
flusher.Flush() flusher.Flush()
@@ -90,21 +85,15 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID") ch := pm.logMonitor.Subscribe()
logger, err := pm.getLogger(logMonitorId) defer pm.logMonitor.Unsubscribe(ch)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
// Send history first if not skipped // Send history first if not skipped
_, skipHistory := c.GetQuery("no-history") _, skipHistory := c.GetQuery("no-history")
if !skipHistory { if !skipHistory {
history := logger.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.SSEvent("message", string(history)) c.SSEvent("message", string(history))
c.Writer.Flush() c.Writer.Flush()
@@ -122,21 +111,3 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
} }
} }
} }
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+89 -405
View File
@@ -4,11 +4,8 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -17,14 +14,13 @@ import (
) )
func TestProxyManager_SwapProcessCorrectly(t *testing.T) { func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
@@ -37,91 +33,58 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName) assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
} }
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
} }
func TestProxyManager_SwapMultiProcess(t *testing.T) { func TestProxyManager_SwapMultiProcess(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), model1: getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), model2: getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "error", Profiles: map[string][]string{
Groups: map[string]GroupConfig{ "test": {model1, model2},
"G1": {
Swap: true,
Exclusive: false,
Members: []string{"model1"},
},
"G2": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
}, },
})
proxy := New(config)
defer proxy.StopProcesses()
tests := []string{"model1", "model2"}
for _, requestedModel := range tests {
t.Run(requestedModel, func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel)
})
} }
// make sure there's two loaded models
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
}
// Test that a persistent group is not affected by the swapping behaviour of
// other groups.
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
Groups: map[string]GroupConfig{
// the forever group is persistent and should not be affected by model1
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model2"},
},
},
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
// make requests to load all models, loading model1 should not affect model2 for modelID, requestedModel := range map[string]string{
tests := []string{"model2", "model1"} "model1": profileModel1,
for _, requestedModel := range tests { "model2": profileModel2,
} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel) assert.Contains(t, w.Body.String(), modelID)
} }
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) // make sure there's two loaded models
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) assert.Len(t, proxy.currentProcesses, 2)
_, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
} }
// When a request for a different model comes in ProxyManager should wait until // When a request for a different model comes in ProxyManager should wait until
@@ -131,15 +94,14 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
t.Skip("skipping slow test") t.Skip("skipping slow test")
} }
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
LogLevel: "error", }
})
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
@@ -166,9 +128,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
mu.Lock() mu.Lock()
var response map[string]string results[key] = w.Body.String()
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"]
mu.Unlock() mu.Unlock()
}(key) }(key)
@@ -184,14 +144,13 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
} }
func TestProxyManager_ListModelsHandler(t *testing.T) { func TestProxyManager_ListModelsHandler(t *testing.T) {
config := Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
LogLevel: "error",
} }
proxy := New(config) proxy := New(config)
@@ -252,6 +211,50 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
assert.Empty(t, expectedModels, "not all expected models were returned") assert.Empty(t, expectedModels, "not all expected models were returned")
} }
func TestProxyManager_ProfileNonMember(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileMemberName := ProcessKeyName("test", model1)
profileNonMemberName := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1},
},
}
proxy := New(config)
defer proxy.StopProcesses()
// actual member of profile
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
// actual model, but non-member will 404
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
}
func TestProxyManager_Shutdown(t *testing.T) { func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations // make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991) model1Config := getTestSimpleResponderConfigPort("model1", 9991)
@@ -263,27 +266,23 @@ func TestProxyManager_Shutdown(t *testing.T) {
model3Config := getTestSimpleResponderConfigPort("model3", 9993) model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/" model3Config.Proxy = "http://localhost:10003/"
config := AddDefaultGroupToConfig(Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2", "model3"},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": model1Config, "model1": model1Config,
"model2": model2Config, "model2": model2Config,
"model3": model3Config, "model3": model3Config,
}, },
LogLevel: "error", }
Groups: map[string]GroupConfig{
"test": {
Swap: false,
Members: []string{"model1", "model2", "model3"},
},
},
})
proxy := New(config) proxy := New(config)
// Start all the processes // Start all the processes
var wg sync.WaitGroup var wg sync.WaitGroup
for _, modelName := range []string{"model1", "model2", "model3"} { for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
wg.Add(1) wg.Add(1)
go func(modelName string) { go func(modelName string) {
defer wg.Done() defer wg.Done()
@@ -291,10 +290,11 @@ func TestProxyManager_Shutdown(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
// send a request to trigger the proxy to load ... this should hang waiting for start up // send a request to trigger the proxy to load
proxy.HandlerFunc(w, req) proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code) assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
//fmt.Println(w.Code, w.Body.String())
}(modelName) }(modelName)
} }
@@ -304,319 +304,3 @@ func TestProxyManager_Shutdown(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
} }
func TestProxyManager_Unload(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")
// give it a bit of time to stop
<-time.After(time.Millisecond * 250)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
}
// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "debug",
})
// Define a helper struct to parse the JSON response.
type RunningResponse struct {
Running []struct {
Model string `json:"model"`
State string `json:"state"`
} `json:"running"`
}
// Create proxy once for all tests
proxy := New(config)
defer proxy.StopProcesses()
t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response RunningResponse
// Check if this is a valid JSON object.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// We should have an empty running array here.
assert.Empty(t, response.Running, "expected no running models")
})
t.Run("single model loaded", func(t *testing.T) {
// Load just a model.
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// Check if we have a single array element.
assert.Len(t, response.Running, 1)
// Is this the right model?
assert.Equal(t, "model1", response.Running[0].Model)
// Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State)
})
}
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte("TheExpectedModel"))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
// Generate random content length between 10 and 20
contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, "TheExpectedModel", response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
}
// Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": modelConfig,
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
requestedModel := "model1"
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses()
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
func TestProxyManager_Upstream(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
}
func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
}
-43
View File
@@ -1,43 +0,0 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
-77
View File
@@ -1,77 +0,0 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}