Compare commits

..

1 Commits

Author SHA1 Message Date
Benson Wong cebf9c4d34 increase health check to a minimum of 5 seconds 2025-02-18 17:35:52 -08:00
27 changed files with 322 additions and 1829 deletions
-23
View File
@@ -1,23 +0,0 @@
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
name: Close inactive issues
on:
schedule:
- cron: "32 1 * * *"
jobs:
close-issues:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v9
with:
days-before-issue-stale: 30
days-before-issue-close: 14
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}
-1
View File
@@ -16,7 +16,6 @@ jobs:
strategy: strategy:
matrix: matrix:
platform: [intel, cuda, vulkan, cpu, musa] platform: [intel, cuda, vulkan, cpu, musa]
fail-fast: false
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
-32
View File
@@ -1,32 +0,0 @@
# This workflow will build a golang project
name: CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
# necessary for testing proxy/Process swapping
- name: Create simple-responder
run: make simple-responder
- name: Test all
run: make test-all
-12
View File
@@ -16,15 +16,3 @@ builds:
goarch: arm64 goarch: arm64
- goos: windows - goos: windows
goarch: arm64 goarch: arm64
# use zip format for windows
archives:
- id: default
format: tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds_info:
group: root
owner: root
format_overrides:
- goos: windows
format: zip
+24 -49
View File
@@ -1,7 +1,4 @@
![llama-swap header image](header.jpeg) ![llama-swap header image](header.jpeg)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap # llama-swap
@@ -14,23 +11,20 @@ Written in golang, it is very easy to install (single binary with no dependancie
- ✅ Easy to deploy: single binary with no dependencies - ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file - ✅ Easy to config: single yaml file
- ✅ On-demand model switching - ✅ On-demand model switching
- ✅ Full control over server settings per model
- ✅ OpenAI API supported endpoints: - ✅ OpenAI API supported endpoints:
- `v1/completions` - `v1/completions`
- `v1/chat/completions` - `v1/chat/completions`
- `v1/embeddings` - `v1/embeddings`
- `v1/rerank` - `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867)) - ✅ Multiple GPU support
- ✅ llama-swap custom API endpoints
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support - ✅ Docker and Podman support
-Full control over server settings per model -Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Remote log monitoring at `/log`
- ✅ Automatic unloading of models from GPUs after timeout
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
## How does llama-swap work? ## How does llama-swap work?
@@ -67,20 +61,13 @@ models:
# Default (and minimum) is 15 seconds # Default (and minimum) is 15 seconds
healthCheckTimeout: 60 healthCheckTimeout: 60
# Valid log levels: debug, info (default), warn, error # Write HTTP logs (useful for troubleshooting), defaults to false
logLevel: info logRequests: true
# define valid model values and the upstream server start # define valid model values and the upstream server start
models: models:
"llama": "llama":
# multiline for readability cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# where to reach the server started by cmd, make sure the ports match # where to reach the server started by cmd, make sure the ports match
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
@@ -101,9 +88,16 @@ models:
# default: 0 = never unload model # default: 0 = never unload model
ttl: 60 ttl: 60
# `useModelName` overrides the model name in the request "qwen":
# and sends a specific name to the upstream server # environment variables to pass to the command
useModelName: "qwen:qwq" env:
- "CUDA_VISIBLE_DEVICES=0"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
# unlisted models do not show up in /v1/models or /upstream lists # unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal # but they can still be requested as normal
@@ -120,7 +114,7 @@ models:
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles eliminates swapping by running multiple models at the same time # profiles make it easy to managing multi model (and gpu) configurations.
# #
# Tips: # Tips:
# - each model must be listening on a unique address and port # - each model must be listening on a unique address and port
@@ -128,20 +122,15 @@ models:
# - the profile will load and unload all models in the profile at the same time # - the profile will load and unload all models in the profile at the same time
profiles: profiles:
coding: coding:
- "qwen"
- "llama" - "llama"
- "qwen-unlisted"
``` ```
### Use Case Examples ### Advanced Examples
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints - [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases. - [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest. - [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
## Configuration
llama-s
</details> </details>
@@ -219,15 +208,9 @@ Of course, CLI access is also supported:
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' curl http://host/logs'
# streams combined logs # streams logs
curl -Ns 'http://host/logs/stream' curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes # stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time' curl -Ns http://host/logs/stream | grep 'eval time'
@@ -266,11 +249,3 @@ StartLimitInterval=30
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target
``` ```
## Star History
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
+3 -3
View File
@@ -1,9 +1,9 @@
# Seconds to wait for llama.cpp to be available to serve requests # Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds # Default (and minimum): 15 seconds
healthCheckTimeout: 90 healthCheckTimeout: 15
# valid log levels: debug, info (default), warn, error # Log HTTP requests helpful for troubleshoot, defaults to False
logLevel: debug logRequests: true
models: models:
"llama": "llama":
-6
View File
@@ -38,12 +38,6 @@ else
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \ | jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}') | sort -r | head -n1 | awk -F '-' '{print $3}')
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}" CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}" CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER" echo "Building ${CONTAINER_TAG} $LS_VER"
-153
View File
@@ -1,153 +0,0 @@
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
## Here's what you you need:
- aider - [installation docs](https://aider.chat/docs/install.html)
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
- 24GB VRAM video card
## Running aider
The goal is getting this command line to work:
```sh
aider --architect \
--no-show-model-warnings \
--model openai/QwQ \
--editor-model openai/qwen-coder-32B \
--model-settings-file aider.model.settings.yml \
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1" \
```
Set `--openai-api-base` to the IP and port where your llama-swap is running.
## Create an aider model settings file
```yaml
# aider.model.settings.yml
#
# !!! important: model names must match llama-swap configuration names !!!
#
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
```
## llama-swap configuration
```yaml
# config.yaml
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
models:
"qwen-coder-32B":
proxy: "http://127.0.0.1:8999"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--cache-type-k q8_0 --cache-type-v q8_0
-ngl 99
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
"QwQ":
proxy: "http://127.0.0.1:9503"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
--min-p 0.01 --top-k 40 --top-p 0.95
-ngl 99
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
```
## Advanced, Dual GPU Configuration
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
In llama-swap's configuration file:
1. add a `profiles` section with `aider` as the profile name
2. using the `env` field to specify the GPU IDs for each model
```yaml
# config.yaml
# Add a profile for aider
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=0"
proxy: "http://127.0.0.1:8999"
cmd: /path/to/llama-server ...
"QwQ":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
cmd: /path/to/llama-server ...
```
Append the profile tag, `aider:`, to the model names in the model settings file
```yaml
# aider.model.settings.yml
- name: "openai/aider:QwQ"
weak_model_name: "openai/aider:qwen-coder-32B-aider"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
- name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
```
Run aider with:
```sh
$ aider --architect \
--no-show-model-warnings \
--model openai/aider:QwQ \
--editor-model openai/aider:qwen-coder-32B \
--config aider.conf.yml \
--model-settings-file aider.model.settings.yml
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1"
```
@@ -1,28 +0,0 @@
# this makes use of llama-swap's profile feature to
# keep the architect and editor models in VRAM on different GPUs
- name: "openai/aider:QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B"
- name: "openai/aider:qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/aider:qwen-coder-32B"
@@ -1,26 +0,0 @@
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
-49
View File
@@ -1,49 +0,0 @@
healthCheckTimeout: 300
logLevel: debug
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
env:
- "CUDA_VISIBLE_DEVICES=0"
aliases:
- coder
proxy: "http://127.0.0.1:8999"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--ctx-size-draft 16000
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
-ngl 99 -ngld 99
--draft-max 16 --draft-min 4 --draft-p-min 0.4
--cache-type-k q8_0 --cache-type-v q8_0
"QwQ":
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6
--repeat-penalty 1.1
--dry-multiplier 0.5
--min-p 0.01
--top-k 40
--top-p 0.95
-ngl 99 -ngld 99
@@ -1,51 +0,0 @@
# Restart llama-swap on config change
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
```bash
#!/bin/bash
#
# A simple watch and restart llama-swap when its configuration
# file changes. Useful for trying out configuration changes
# without manually restarting the server each time.
if [ -z "$1" ]; then
echo "Usage: $0 <path to config.yaml>"
exit 1
fi
while true; do
# Start the process again
./llama-swap-linux-amd64 -config $1 -listen :1867 &
PID=$!
echo "Started llama-swap with PID $PID"
# Wait for modifications in the specified directory or file
inotifywait -e modify "$1"
# Check if process exists before sending signal
if kill -0 $PID 2>/dev/null; then
echo "Sending SIGTERM to $PID"
kill -SIGTERM $PID
wait $PID
else
echo "Process $PID no longer exists"
fi
sleep 1
done
```
## Usage and output example
```bash
$ ./watch-and-restart.sh config.yaml
Started llama-swap with PID 495455
Setting up watches.
Watches established.
llama-swap listening on :1867
Sending SIGTERM to 495455
Shutting down llama-swap
Started llama-swap with PID 495486
Setting up watches.
Watches established.
llama-swap listening on :1867
```
+6 -10
View File
@@ -3,11 +3,7 @@ module github.com/mostlygeek/llama-swap
go 1.23.0 go 1.23.0
require ( require (
github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
@@ -19,10 +15,12 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.10.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
@@ -31,14 +29,12 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.37.0 // indirect golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
) )
-18
View File
@@ -57,16 +57,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -78,28 +68,20 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
-61
View File
@@ -12,14 +12,12 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
) )
func main() { func main() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
// Define a command-line flag for the port // Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on") port := flag.String("port", "8080", "port to listen on")
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
// Define a command-line flag for the response message // Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with") responseMessage := flag.String("respond", "hi", "message to respond with")
@@ -43,70 +41,11 @@ func main() {
c.String(200, *responseMessage) c.String(200, *responseMessage)
}) })
// for issue #62 to check model name strips profile slug
// has to be one of the openAI API endpoints that llama-swap proxies
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
r.POST("/v1/audio/speech", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
if modelName != *expectedModel {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
return
} else {
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
})
r.POST("/v1/completions", func(c *gin.Context) { r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage) c.String(200, *responseMessage)
}) })
// issue #41
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
// Parse the multipart form
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
// Get the model from the form values
model := c.Request.FormValue("model")
if model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
return
}
// Get the file from the form
file, _, err := c.Request.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
defer file.Close()
// Read the file content to get its size
fileBytes, err := io.ReadAll(file)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
return
}
fileSize := len(fileBytes)
// Return a JSON response with the model and transcription text including file size
c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model,
})
})
r.GET("/slow-respond", func(c *gin.Context) { r.GET("/slow-respond", func(c *gin.Context) {
echo := c.Query("echo") echo := c.Query("echo")
delay := c.Query("delay") delay := c.Query("delay")
-2
View File
@@ -17,7 +17,6 @@ type ModelConfig struct {
CheckEndpoint string `yaml:"checkEndpoint"` CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"` UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"` Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
} }
func (m *ModelConfig) SanitizedCommand() ([]string, error) { func (m *ModelConfig) SanitizedCommand() ([]string, error) {
@@ -27,7 +26,6 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
type Config struct { type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"` LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
Models map[string]ModelConfig `yaml:"models"` Models map[string]ModelConfig `yaml:"models"`
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
+71 -184
View File
@@ -12,65 +12,32 @@
flex-direction: column; flex-direction: column;
font-family: "Courier New", Courier, monospace; font-family: "Courier New", Courier, monospace;
} }
.log-container { #log-controls {
display: flex;
flex: 1;
gap: 0.5em;
margin: 0.5em; margin: 0.5em;
min-height: 0;
}
.log-column {
display: flex; display: flex;
flex-direction: column; align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1; flex: 1;
min-width: 0;
transition: flex 0.3s ease;
} }
.log-column.minimized { #log-controls input:focus {
flex: 0.1; outline: none; /* Ensures no outline is shown when the input is focused */
max-width: 50px;
border: 1px solid #777;
color: green;
} }
.log-controls { #log-stream {
display: grid;
grid-template-columns: 1fr auto;
gap: 0.5em;
margin-bottom: 0.5em;
}
.log-controls input {
width: 100%;
padding: 4px;
}
.log-controls input:focus {
outline: none;
}
.log-stream {
flex: 1; flex: 1;
margin: 0.5em;
padding: 1em; padding: 1em;
background: #f4f4f4; background: #f4f4f4;
overflow-y: auto; overflow-y: auto;
white-space: pre-wrap; white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; word-wrap: break-word; /* Ensures long words wrap */
min-height: 0;
} }
.regex-error { .regex-error {
background-color: #ff0000 !important; background-color: #ff0000 !important;
} }
/* Make headers clickable and show pointer cursor */
h2 {
cursor: pointer;
user-select: none;
margin: 0 0 0.5em 0;
padding: 0.5em;
}
h2:hover {
background-color: rgba(0, 0, 0, 0.05);
}
/* Dark mode styles */ /* Dark mode styles */
@media (prefers-color-scheme: dark) { @media (prefers-color-scheme: dark) {
body { body {
@@ -78,181 +45,101 @@
color: #fff; color: #fff;
} }
.log-stream { #log-stream {
background: #444; background: #444;
color: #fff; color: #fff;
} }
.log-controls input { #log-controls input {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
.log-controls button { #log-controls button {
background: #555; background: #555;
color: #fff; color: #fff;
border: 1px solid #777; border: 1px solid #777;
} }
h2:hover {
background-color: rgba(255, 255, 255, 0.1);
}
}
/* Hide content when minimized */
.log-column.minimized .log-controls,
.log-column.minimized .log-stream {
display: none;
}
.log-column.minimized h2 {
writing-mode: vertical-rl;
text-orientation: mixed;
transform: rotate(180deg);
white-space: nowrap;
margin: auto;
} }
</style> </style>
</head> </head>
<body> <body>
<div class="log-container"> <pre id="log-stream">Waiting for logs...</pre>
<div class="log-column"> <div id="log-controls">
<h2>Proxy Logs</h2> <input type="text" id="filter-input" placeholder="regex filter">
<div class="log-controls"> <button id="clear-button">clear</button>
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
<button id="proxy-clear-button">clear</button>
</div>
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
</div>
<div class="log-column minimized">
<h2>Upstream Logs</h2>
<div class="log-controls">
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
<button id="upstream-clear-button">clear</button>
</div>
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
</div>
</div> </div>
<script> <script>
class LogStream { const logStream = document.getElementById('log-stream');
constructor(streamElement, filterInput, clearButton, endpoint) { const filterInput = document.getElementById('filter-input');
this.streamElement = streamElement; var logData = "";
this.filterInput = filterInput; let regexFilter = null;
this.clearButton = clearButton;
this.endpoint = endpoint;
this.logData = "";
this.regexFilter = null;
this.eventSource = null;
this.initialize(); function setupEventSource() {
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
eventSource.onmessage = function(event) {
logData += event.data;
render()
};
eventSource.onerror = function(err) {
logData = "EventSource failed: " + err.message;
};
} else {
logData = "SSE Not supported by this browser."
}
} }
initialize() { // poor-ai's react ¯\_(ツ)_/¯
this.filterInput.addEventListener('input', () => this.updateFilter()); function render() {
this.clearButton.addEventListener('click', () => { if (regexFilter) {
this.filterInput.value = ""; const lines = logData.split('\n');
this.regexFilter = null; const filteredLines = lines.filter(line => {
this.render(); return regexFilter === null || regexFilter.test(line);
}); });
this.setupEventSource();
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
} }
setupEventSource() { logStream.scrollTop = logStream.scrollHeight;
if (typeof(EventSource) === "undefined") {
this.logData = "SSE Not supported by this browser.";
this.render();
return;
}
const connect = () => {
this.eventSource = new EventSource(this.endpoint);
this.eventSource.onmessage = (event) => {
this.logData += event.data;
this.render();
};
this.eventSource.onerror = (err) => {
// Close the current connection
this.eventSource.close();
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
this.render();
// Attempt to reconnect after 5 seconds
setTimeout(() => {
this.logData += "Attempting to reconnect...\n";
this.render();
connect();
}, 5000);
};
};
// Initial connection
connect();
}
render() {
let content = this.logData;
if (this.regexFilter) {
const lines = content.split('\n');
const filteredLines = lines.filter(line => this.regexFilter.test(line));
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
}
this.streamElement.textContent = content;
this.streamElement.scrollTop = this.streamElement.scrollHeight;
}
updateFilter() {
const pattern = this.filterInput.value.trim();
this.filterInput.classList.remove('regex-error');
if (!pattern) {
this.regexFilter = null;
this.render();
return;
} }
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try { try {
this.regexFilter = new RegExp(pattern); regexFilter = new RegExp(pattern);
} catch (e) { } catch (e) {
console.error("Invalid regex pattern:", e); console.error("Invalid regex pattern:", e);
this.regexFilter = null; regexFilter = null;
this.filterInput.classList.add('regex-error'); filterInput.classList.add('regex-error');
return; return
}
} else {
regexFilter = null;
} }
this.render(); render();
}
} }
// Initialize both log streams filterInput.addEventListener('input', updateFilter);
document.addEventListener('DOMContentLoaded', () => { document.getElementById('clear-button').addEventListener('click', () => {
new LogStream( filterInput.value = "";
document.getElementById('proxy-log-stream'), regexFilter = null;
document.getElementById('proxy-filter-input'), render();
document.getElementById('proxy-clear-button'),
"/logs/streamSSE/proxy"
);
new LogStream(
document.getElementById('upstream-log-stream'),
document.getElementById('upstream-filter-input'),
document.getElementById('upstream-clear-button'),
"/logs/streamSSE/upstream"
);
// Initialize clickable headers
document.querySelectorAll('h2').forEach(header => {
header.addEventListener('click', () => {
const column = header.closest('.log-column');
column.classList.toggle('minimized');
});
});
}); });
setupEventSource();
updateFilter();
</script> </script>
</body> </body>
</html> </html>
-90
View File
@@ -2,21 +2,11 @@ package proxy
import ( import (
"container/ring" "container/ring"
"fmt"
"io" "io"
"os" "os"
"sync" "sync"
) )
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct { type LogMonitor struct {
clients map[chan []byte]bool clients map[chan []byte]bool
mu sync.RWMutex mu sync.RWMutex
@@ -25,10 +15,6 @@ type LogMonitor struct {
// typically this can be os.Stdout // typically this can be os.Stdout
stdout io.Writer stdout io.Writer
// logging levels
level LogLevel
prefix string
} }
func NewLogMonitor() *LogMonitor { func NewLogMonitor() *LogMonitor {
@@ -40,8 +26,6 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool), clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout, stdout: stdout,
level: LevelInfo,
prefix: "",
} }
} }
@@ -110,77 +94,3 @@ func (w *LogMonitor) broadcast(msg []byte) {
} }
} }
} }
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+91 -121
View File
@@ -33,12 +33,8 @@ type Process struct {
ID string ID string
config ModelConfig config ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
logMonitor *LogMonitor
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time lastRequestHandled time.Time
@@ -55,68 +51,54 @@ type Process struct {
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
processLogger: processLogger, logMonitor: logMonitor,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout, healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped, state: StateStopped,
shutdownCtx: ctx, shutdownCtx: ctx,
shutdownCancel: cancel, shutdownCancel: cancel,
} }
} }
// LogMonitor returns the log monitor associated with the process. func (p *Process) setState(newState ProcessState) error {
func (p *Process) LogMonitor() *LogMonitor { // enforce valid state transitions
return p.processLogger invalidTransition := false
} if p.state == StateStopped {
// stopped -> starting
// custom error types for swapping state if newState != StateStarting {
var ( invalidTransition = true
ErrExpectedStateMismatch = errors.New("expected state mismatch") }
ErrInvalidStateTransition = errors.New("invalid state transition") } else if p.state == StateStarting {
) // starting -> ready | failed | stopping
if newState != StateReady && newState != StateFailed && newState != StateStopping {
// swapState performs a compare and swap of the state atomically. It returns the current state invalidTransition = true
// and an error if the swap failed. }
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) { } else if p.state == StateReady {
p.stateMutex.Lock() // ready -> stopping
defer p.stateMutex.Unlock() if newState != StateStopping {
invalidTransition = true
if p.state != expectedState { }
return p.state, ErrExpectedStateMismatch } else if p.state == StateStopping {
// stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
} }
if !isValidTransition(p.state, newState) { if invalidTransition {
p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState) //panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
return p.state, ErrInvalidStateTransition return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
} }
p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState)
p.state = newState p.state = newState
return p.state, nil return nil
}
// Helper function to encapsulate transition rules
func isValidTransition(from, to ProcessState) bool {
switch from {
case StateStopped:
return to == StateStarting
case StateStarting:
return to == StateReady || to == StateFailed || to == StateStopping
case StateReady:
return to == StateStopping
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateFailed, StateShutdown:
return false // No transitions allowed from these states
}
return false
} }
func (p *Process) CurrentState() ProcessState { func (p *Process) CurrentState() ProcessState {
@@ -134,48 +116,47 @@ func (p *Process) start() error {
return fmt.Errorf("can not start(), upstream proxy missing") return fmt.Errorf("can not start(), upstream proxy missing")
} }
args, err := p.config.SanitizedCommand() // wait for the other start() to complete
if err != nil { curState := p.CurrentState()
return fmt.Errorf("unable to get sanitized command: %v", err)
if curState == StateReady {
return nil
} }
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
// already starting, just wait for it to complete and expect
// it to be be in the Ready start after. If not, return an error
if curState == StateStarting { if curState == StateStarting {
p.waitStarting.Wait() p.waitStarting.Wait()
if state := p.CurrentState(); state == StateReady {
if state := p.CurrentState(); state != StateReady {
return fmt.Errorf("start() failed current state: %v", state)
}
return nil return nil
} else {
return fmt.Errorf("process was already starting but wound up in state %v", state)
}
} else {
return fmt.Errorf("processes was in state %v when start() was called", curState)
}
} else {
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
} }
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if err := p.setState(StateStarting); err != nil {
return err
} }
p.waitStarting.Add(1) p.waitStarting.Add(1)
defer p.waitStarting.Done() defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand()
if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err)
}
p.cmd = exec.Command(args[0], args[1:]...) p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.processLogger p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.processLogger p.cmd.Stderr = p.logMonitor
p.cmd.Env = p.config.Env p.cmd.Env = p.config.Env
err = p.cmd.Start() err = p.cmd.Start()
// Set process state to failed
if err != nil { if err != nil {
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil { p.setState(StateFailed)
return fmt.Errorf(
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
err, curState, swapErr,
)
}
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
@@ -210,35 +191,31 @@ func (p *Process) start() error {
) )
defer cancelHealthCheck() defer cancelHealthCheck()
// Health check loop
loop: loop:
// Ready Check loop
for { for {
select { select {
case <-checkDeadline.Done(): case <-checkDeadline.Done():
if curState, err := p.swapState(StateStarting, StateFailed); err != nil { p.setState(StateFailed)
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState) return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
} else {
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
case <-p.shutdownCtx.Done(): case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown") return errors.New("health check interrupted due to shutdown")
default: default:
if err := p.checkHealthEndpoint(healthURL); err == nil { if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("Health check passed on %s", healthURL)
cancelHealthCheck() cancelHealthCheck()
break loop break loop
} else { } else {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline() endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime) ttl := time.Until(endTime)
p.proxyLogger.Infof("Connection refused on %s, retrying in %.0fs", healthURL, ttl.Seconds()) fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
} else { } else {
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err) fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
} }
} }
} }
<-time.After(p.healthCheckLoopInterval) <-time.After(5 * time.Second)
} }
} }
@@ -249,7 +226,7 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) { for range time.Tick(time.Second) {
if p.CurrentState() != StateReady { if p.state != StateReady {
return return
} }
@@ -257,8 +234,7 @@ func (p *Process) start() error {
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return return
} }
@@ -266,28 +242,26 @@ func (p *Process) start() error {
}() }()
} }
if curState, err := p.swapState(StateStarting, StateReady); err != nil { return p.setState(StateReady)
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
} else {
return nil
}
} }
func (p *Process) Stop() { func (p *Process) Stop() {
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// calling Stop() when state is invalid is a no-op // calling Stop() when state is invalid is a no-op
if curState, err := p.swapState(StateReady, StateStopping); err != nil { if err := p.setState(StateStopping); err != nil {
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState) fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
return return
} }
// stop the process with a graceful exit timeout // stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil { if err := p.setState(StateStopped); err != nil {
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState) panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
} }
} }
@@ -295,19 +269,24 @@ func (p *Process) Stop() {
// of time for any inflight requests to complete before shutting down. If the Process // of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down // is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() { func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel() p.shutdownCancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second) p.stopCommand(5 * time.Second)
p.state = StateShutdown if err := p.setState(StateShutdown); err != nil {
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
}
p.setState(StateShutdown)
} }
// stopCommand will send a SIGTERM to the process and wait for it to exit. // stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL. // If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) { func (p *Process) stopCommand(sigtermTTL time.Duration) {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("Process [%s] stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL) sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout() defer cancelTimeout()
@@ -317,32 +296,31 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
}() }()
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
return return
} }
if err := p.terminateProcess(); err != nil { p.cmd.Process.Signal(syscall.SIGTERM)
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
}
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.cmd.Process.Kill() p.cmd.Process.Kill()
case err := <-sigtermNormal: case err := <-sigtermNormal:
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok { if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno) fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok { } else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") { if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") { } else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID) fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
} else { } else {
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
} }
} else { } else {
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err) fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
} }
} }
} }
@@ -374,8 +352,6 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
} }
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
requestBeginTime := time.Now()
var startDuration time.Duration
// prevent new requests from being made while stopping or irrecoverable // prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState() currentState := p.CurrentState()
@@ -392,13 +368,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// start the process on demand // start the process on demand
if p.CurrentState() != StateReady { if p.CurrentState() != StateReady {
beginStartTime := time.Now()
if err := p.start(); err != nil { if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err) errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway) http.Error(w, errstr, http.StatusBadGateway)
return return
} }
startDuration = time.Since(beginStartTime)
} }
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
@@ -442,8 +416,4 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("Process [%s] request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
} }
-9
View File
@@ -1,9 +0,0 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
-14
View File
@@ -1,14 +0,0 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+39 -47
View File
@@ -5,6 +5,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -12,17 +13,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var (
discardLogger = NewLogMonitorWriter(io.Discard)
)
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
// Create a process // Create a process
process := NewProcess("test-process", 5, config, discardLogger, discardLogger) process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop() defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
@@ -55,10 +52,11 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start() // are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) { func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, discardLogger, discardLogger) process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop() defer process.Stop()
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -86,7 +84,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }
process := NewProcess("broken", 1, config, discardLogger, discardLogger) process := NewProcess("broken", 1, config, NewLogMonitor())
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -111,7 +109,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger) process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop() defer process.Stop()
// this should take 4 seconds // this should take 4 seconds
@@ -153,7 +151,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter) assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, discardLogger, discardLogger) process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop() defer process.Stop()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@@ -171,8 +169,6 @@ func TestProcess_LowTTLValue(t *testing.T) {
} }
// issue #19 // issue #19
// This test makes sure using Process.Stop() does not affect pending HTTP
// requests. All HTTP requests in this test should complete successfully.
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skipping slow test") t.Skip("skipping slow test")
@@ -180,7 +176,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345" expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage) config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, discardLogger, discardLogger) process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop() defer process.Stop()
results := map[string]string{ results := map[string]string{
@@ -196,9 +192,8 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(key string) { go func(key string) {
defer wg.Done() defer wg.Done()
// send a request where simple-responder is will wait 300ms before responding // send a request that should take 5 * 200ms (1 second) to complete
// this will simulate an in-progress request. req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
process.ProxyRequest(w, req) process.ProxyRequest(w, req)
@@ -214,9 +209,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
}(key) }(key)
} }
// Stop the process while requests are still being processed // stop the requests in the middle
go func() { go func() {
<-time.After(150 * time.Millisecond) <-time.After(500 * time.Millisecond)
process.Stop() process.Stop()
}() }()
@@ -227,40 +222,39 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
} }
} }
func TestProcess_SwapState(t *testing.T) { func TestSetState(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
currentState ProcessState currentState ProcessState
expectedState ProcessState
newState ProcessState newState ProcessState
expectedError error expectedError error
expectedResult ProcessState expectedResult ProcessState
}{ }{
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting}, {"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady}, {"Starting to Ready", StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed}, {"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping}, {"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping}, {"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped}, {"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown}, {"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped}, {"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting}, {"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady}, {"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady}, {"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping}, {"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed}, {"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed}, {"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown}, {"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown}, {"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger) p := &Process{
p.state = test.currentState state: test.currentState,
}
resultState, err := p.swapState(test.expectedState, test.newState) err := p.setState(test.newState)
if err != nil && test.expectedError == nil { if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil { } else if err == nil && test.expectedError != nil {
@@ -271,8 +265,8 @@ func TestProcess_SwapState(t *testing.T) {
} }
} }
if resultState != test.expectedResult { if p.state != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState) t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
} }
}) })
} }
@@ -283,6 +277,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test") t.Skip("skipping long shutdown test")
} }
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong // make a config where the healthcheck will always fail because port is wrong
@@ -290,16 +285,13 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test" config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30 healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger) process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
// start a goroutine to simulate a shutdown // start a goroutine to simulate a shutdown
var wg sync.WaitGroup var wg sync.WaitGroup
go func() { go func() {
defer wg.Done() defer wg.Done()
<-time.After(time.Millisecond * 500) <-time.After(time.Second * 2)
process.Shutdown() process.Shutdown()
}() }()
wg.Add(1) wg.Add(1)
+27 -226
View File
@@ -5,9 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"os"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -15,8 +13,6 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
const ( const (
@@ -28,47 +24,19 @@ type ProxyManager struct {
config *Config config *Config
currentProcesses map[string]*Process currentProcesses map[string]*Process
logMonitor *LogMonitor
ginEngine *gin.Engine ginEngine *gin.Engine
// logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
} }
func New(config *Config) *ProxyManager { func New(config *Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{ pm := &ProxyManager{
config: config, config: config,
currentProcesses: make(map[string]*Process), currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(),
ginEngine: gin.New(), ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
} }
if config.LogRequests {
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
// Start timer // Start timer
start := time.Now() start := time.Now()
@@ -87,8 +55,9 @@ func New(config *Config) *ProxyManager {
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
bodySize := c.Writer.Size() bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v", fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP, clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method, method,
path, path,
c.Request.Proto, c.Request.Proto,
@@ -98,26 +67,16 @@ func New(config *Config) *ProxyManager {
duration, duration,
) )
}) })
}
// see: issue: #81, #77 and #42 for CORS issues // see: https://github.com/mostlygeek/llama-swap/issues/42
// respond with permissive OPTIONS for any endpoint // respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" { if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
// allow whatever the client requested by default c.AbortWithStatus(204)
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return return
} }
c.Next() c.Next()
@@ -134,7 +93,6 @@ func New(config *Config) *ProxyManager {
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -142,16 +100,10 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex) pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream) pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/", func(c *gin.Context) { pm.ginEngine.GET("/", func(c *gin.Context) {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
c.Header("Content-Type", "text/html") c.Header("Content-Type", "text/html")
@@ -270,7 +222,11 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
defer pm.Unlock() defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR // Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := splitRequestedModel(requestedModel) profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
if profileName != "" { if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found { if _, found := pm.config.Profiles[profileName]; !found {
@@ -303,20 +259,19 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
requestedProcessKey := ProcessKeyName(profileName, realModelName) requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found { if process, found := pm.currentProcesses[requestedProcessKey]; found {
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
return process, nil return process, nil
} }
// stop all running models // stop all running models
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
pm.stopProcesses() pm.stopProcesses()
if profileName == "" { if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName) modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found { if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName) return nil, fmt.Errorf("could not find configuration for %s", realModelName)
} }
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger) process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID) processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process pm.currentProcesses[processKey] = process
} else { } else {
@@ -327,7 +282,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName) return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
} }
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger) process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID) processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process pm.currentProcesses[processKey] = process
} }
@@ -387,36 +342,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return return
} }
requestedModel := gjson.GetBytes(bodyBytes, "model").String() var requestBody map[string]interface{}
if requestedModel == "" { if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
return
}
model, ok := requestBody["model"].(string)
if !ok {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
} }
process, err := pm.swapModel(requestedModel) if process, err := pm.swapModel(model); err != nil {
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return return
}
// issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
} else { } else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11 // dechunk it as we already have all the body bytes see issue #11
@@ -424,110 +364,7 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request) process.ProxyRequest(c.Writer, c.Request)
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
} }
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Swap to the requested model
process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
}
// Get profile name and model name from the requested model
profileName, modelName := splitRequestedModel(requestedModel)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
if process.config.UseModelName != "" {
fieldValue = process.config.UseModelName
} else if profileName != "" {
fieldValue = modelName
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open()
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
return
}
if _, err = io.Copy(formFile, file); err != nil {
file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return
}
file.Close()
}
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// Use the modified request for proxying
process.ProxyRequest(c.Writer, modifiedReq)
} }
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
@@ -540,42 +377,6 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
} }
} }
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses()
c.String(http.StatusOK, "OK")
}
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, process := range pm.currentProcesses {
// Append the process ID and State (multiple entries if profiles are being used).
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
})
}
// Put the results under the `running` key.
response := gin.H{
"running": runningProcesses,
}
context.JSON(http.StatusOK, response) // Always return 200 OK
}
func ProcessKeyName(groupName, modelName string) string { func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName return groupName + PROFILE_SPLIT_CHAR + modelName
} }
func splitRequestedModel(requestedModel string) (string, string) {
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
return profileName, modelName
}
+8 -37
View File
@@ -9,6 +9,7 @@ import (
) )
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept") accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") { if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html // Set the Content-Type header to text/html
@@ -27,7 +28,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
} }
} else { } else {
c.Header("Content-Type", "text/plain") c.Header("Content-Type", "text/plain")
history := pm.muxLogger.GetHistory() history := pm.logMonitor.GetHistory()
_, err := c.Writer.Write(history) _, err := c.Writer.Write(history)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
@@ -41,14 +42,8 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked") c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID") ch := pm.logMonitor.Subscribe()
logger, err := pm.getLogger(logMonitorId) defer pm.logMonitor.Unsubscribe(ch)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
@@ -61,7 +56,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped // Send history first if not skipped
if !skipHistory { if !skipHistory {
history := logger.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.Writer.Write(history) c.Writer.Write(history)
flusher.Flush() flusher.Flush()
@@ -90,21 +85,15 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID") ch := pm.logMonitor.Subscribe()
logger, err := pm.getLogger(logMonitorId) defer pm.logMonitor.Unsubscribe(ch)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done() notify := c.Request.Context().Done()
// Send history first if not skipped // Send history first if not skipped
_, skipHistory := c.GetQuery("no-history") _, skipHistory := c.GetQuery("no-history")
if !skipHistory { if !skipHistory {
history := logger.GetHistory() history := pm.logMonitor.GetHistory()
if len(history) != 0 { if len(history) != 0 {
c.SSEvent("message", string(history)) c.SSEvent("message", string(history))
c.Writer.Flush() c.Writer.Flush()
@@ -122,21 +111,3 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
} }
} }
} }
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
-404
View File
@@ -4,8 +4,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync" "sync"
@@ -306,405 +304,3 @@ func TestProxyManager_Shutdown(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
} }
func TestProxyManager_Unload(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
}
proxy := New(config)
proc, err := proxy.swapModel("model1")
assert.NoError(t, err)
assert.NotNil(t, proc)
assert.Len(t, proxy.currentProcesses, 1)
req := httptest.NewRequest("GET", "/unload", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// issue 62, strip profile slug from model name
func TestProxyManager_StripProfileSlug(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
}
// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
}
// Define a helper struct to parse the JSON response.
type RunningResponse struct {
Running []struct {
Model string `json:"model"`
State string `json:"state"`
} `json:"running"`
}
// Create proxy once for all tests
proxy := New(config)
defer proxy.StopProcesses()
t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response RunningResponse
// Check if this is a valid JSON object.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// We should have an empty running array here.
assert.Empty(t, response.Running, "expected no running models")
})
t.Run("single model loaded", func(t *testing.T) {
// Load just a model.
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// Check if we have a single array element.
assert.Len(t, response.Running, 1)
// Is this the right model?
assert.Equal(t, "model1", response.Running[0].Model)
// Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State)
})
t.Run("multiple models via profile", func(t *testing.T) {
// Load more than one model.
for _, model := range []string{"model1", "model2"} {
profileModel := ProcessKeyName("test", model)
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Simulate the browser call.
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
// The JSON response must be valid.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// The response should contain 2 models.
assert.Len(t, response.Running, 2)
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
}
// Iterate through the models and check their states as well.
for _, entry := range response.Running {
_, exists := expectedModels[entry.Model]
assert.True(t, exists, "unexpected model %s", entry.Model)
assert.Equal(t, "ready", entry.State)
delete(expectedModels, entry.Model)
}
// Since we deleted each model while testing for its validity we should have no more models in the response.
assert.Empty(t, expectedModels, "unexpected additional models in response")
})
}
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"},
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
testCases := []struct {
name string
modelInput string
expectModel string
}{
{
name: "With Profile Prefix",
modelInput: "test:TheExpectedModel",
expectModel: "TheExpectedModel", // Profile prefix should be stripped
},
{
name: "Without Profile Prefix",
modelInput: "TheExpectedModel",
expectModel: "TheExpectedModel", // Should remain the same
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tc.modelInput))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
// Generate random content length between 10 and 20
contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tc.expectModel, response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
})
}
}
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
}
// Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{
"model1": modelConfig,
},
}
proxy := New(config)
defer proxy.StopProcesses()
tests := []struct {
description string
requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
for _, tt := range tests {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogRequests: true,
}
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses()
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
-43
View File
@@ -1,43 +0,0 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
-77
View File
@@ -1,77 +0,0 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}