Compare commits

...

44 Commits

Author SHA1 Message Date
Benson Wong b3d331da0d Properly strip profile name slug from models fixes (#62)
The profile slug in a model name, `profile:model`, is specific to
llama-swap. This strips `profile:` out of the model name request so
upstreams that expect just `model` work and do not require knowing about
the profile slug.
2025-03-09 12:41:52 -07:00
Benson Wong 62275e078d add examples to restart on config change #59 2025-03-06 10:50:29 -08:00
Benson Wong 88916059e1 add /unload to docs 2025-03-03 10:44:16 -08:00
Benson Wong 082d5d0fc5 Add /unload endpoint (#58) to unload all currently running models 2025-03-03 10:33:36 -08:00
Benson Wong 53338938bd increase health check to a minimum of 5 seconds 2025-03-03 10:04:08 -08:00
Benson Wong af653347ae Update README.md w/ starhistory graph 2025-02-27 16:43:34 -08:00
Benson Wong 1e25b44a06 add workflow_dispatch to release action 2025-02-18 17:27:43 -08:00
Benson Wong 0815bb4cc3 Add windows to goreleaser #54 2025-02-18 17:26:43 -08:00
daschiller 7187cfe52e add Windows build support to Makefile (#54) 2025-02-18 17:24:31 -08:00
Benson Wong 24089d2d9c remove "no musa container" note from README 2025-02-18 16:38:48 -08:00
Benson Wong ebabe55ff3 Delete untagged packages after build and push (#55) 2025-02-18 10:32:32 -08:00
Benson Wong 41a338297c deletion of untagged containers happen after build-and-push 2025-02-18 10:11:59 -08:00
Benson Wong 7e3353efeb add action step to remove untagged containers 2025-02-18 10:08:41 -08:00
Benson Wong 4ed58fb173 update container build action 2025-02-18 09:59:06 -08:00
Benson Wong f5a2be698d revert package src until new ggml-org has them 2025-02-15 18:23:58 -08:00
Benson Wong f5e6ec3b7a fix package src in containerfile 2025-02-15 18:20:35 -08:00
Benson Wong 3f462da146 switch package source from ggerganov to ggml-org 2025-02-15 18:18:49 -08:00
Benson Wong 48bd766536 Update README.md 2025-02-14 22:05:52 -08:00
Benson Wong 8d319da4dd improve README organization (i think...) 2025-02-14 15:59:12 -08:00
Benson Wong be7c502448 improve docs 2025-02-14 15:47:31 -08:00
Benson Wong 92336f00bf more container build fixes 2025-02-14 15:34:38 -08:00
Benson Wong ed2a50d9a6 fix bug in build-container.sh 2025-02-14 15:27:56 -08:00
Benson Wong 0acfdb9f78 update workflow to build cpu and disable musa 2025-02-14 15:26:59 -08:00
Benson Wong 96a8ea0241 add cpu docker container build 2025-02-14 15:25:45 -08:00
Benson Wong f20f2c9b7a add docs and container build improvements #43 2025-02-14 12:20:07 -08:00
Benson Wong 7a97c38828 enable parallel container built #46 2025-02-14 11:04:33 -08:00
Benson Wong 4885132565 more permissions futzing 2025-02-14 11:02:15 -08:00
Benson Wong 8b46a0b7f1 grant package:write to container workflow #46 2025-02-14 10:55:30 -08:00
Benson Wong 1b6736ec6f rename workflow for containers 2025-02-14 10:50:15 -08:00
Benson Wong ddc1ce031e fix container file name #46 2025-02-14 10:49:44 -08:00
Benson Wong 11d024bbaa just build cuda while debugging 2025-02-14 10:48:06 -08:00
Benson Wong 43e23c16dc add check for GITHUB_TOKEN #46 2025-02-14 10:47:25 -08:00
Benson Wong f9c8e763ba add execute bit on build-container.sh 2025-02-14 10:44:53 -08:00
Benson Wong d7e1bb9f7c add GITHUB_TOKEN to container build env 2025-02-14 10:43:44 -08:00
Benson Wong ab93460a8b first container code (#52) 2025-02-14 10:39:25 -08:00
Benson Wong 13d4552edc Add FreeBSD/amd64 to auto built releases (#51) 2025-02-13 16:44:31 -08:00
Benson Wong 6667e307a2 Update README.md 2025-02-08 10:28:35 -08:00
Benson Wong 7ac446e6a9 Update README.md 2025-02-08 10:26:11 -08:00
Benson Wong eab9795bcc remove panic() when cmd or process is nil 2025-02-07 14:00:32 -08:00
Benson Wong 09bdd86b54 Improve shutdown behaviour (#47) (#49)
Introduce `Process.Shutdown()` and `ProxyManager.Shutdown()`. These two function required a lot of internal process state management refactoring. A key benefit is that `Process.start()` is now interruptable. When `Shutdown()` is called it will break the long health check loop. 

State management within Process is also improved. Added `starting`, `stopping` and `shutdown` states. Additionally, introduced a simple finite state machine to manage transitions.
2025-02-05 17:19:59 -08:00
Benson Wong 85cd74a51c Improve process start and stop reliability (#38)
Refactor Process.start()/Stop() logic (#38)
- remove cmd.Wait() call in start(). This seems to conflict with the one
  in .Stop(). Removing it eliminated no child errors
- eliminate goroutines in .start() as it no longer required
2025-02-03 11:50:38 -08:00
Benson Wong 314d2f2212 remove cmd_stop configuration and functionality from PR #40 (#44)
* remove cmd_stop functionality from #40
2025-01-31 12:42:44 -08:00
Benson Wong fad25f3e11 Use client request context in proxy request (#43)
Canceled or closed HTTP requests from clients will also stop the proxied
HTTP requests to upstreamed servers.
2025-01-31 10:21:49 -08:00
Benson Wong 2c3e3e27f7 Support OPTIONS requests (#42)
Add middleware that responds with permissive OPTIONS headers
for all request paths.
2025-01-31 10:09:07 -08:00
20 changed files with 860 additions and 272 deletions
+45
View File
@@ -0,0 +1,45 @@
name: Build Containers
on:
# time has no specific meaning, trying to time it after
# the llama.cpp daily packages are published
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
schedule:
- cron: "37 5 * * *"
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
build-and-push:
runs-on: ubuntu-latest
strategy:
matrix:
platform: [intel, cuda, vulkan, cpu, musa]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Run build-container
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: ./docker/build-container.sh ${{ matrix.platform }} true
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
# see: https://github.com/actions/delete-package-versions/issues/74
delete-untagged-containers:
needs: build-and-push
runs-on: ubuntu-latest
steps:
- uses: actions/delete-package-versions@v5
with:
package-name: 'llama-swap'
package-type: 'container'
delete-only-untagged-versions: 'true'
+3
View File
@@ -5,6 +5,9 @@ on:
tags:
- '*'
# Allows manual triggering of the workflow
workflow_dispatch:
permissions:
contents: write
+8 -1
View File
@@ -6,6 +6,13 @@ builds:
goos:
- linux
- darwin
- freebsd
- windows
goarch:
- amd64
- arm64
- arm64
ignore:
- goos: freebsd
goarch: arm64
- goos: windows
goarch: arm64
+6 -1
View File
@@ -35,6 +35,11 @@ linux:
@echo "Building Linux binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
# Build Windows binary
windows:
@echo "Building Windows binary..."
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
# for testing proxy.Process
simple-responder:
@echo "Building simple responder"
@@ -60,4 +65,4 @@ release:
git tag "$$new_tag";
# Phony targets
.PHONY: all clean osx linux
.PHONY: all clean mac linux windows simple-responder
+114 -33
View File
@@ -1,21 +1,10 @@
# llama-swap
![llama-swap header image](header.jpeg)
# Introduction
# llama-swap
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
## How does it work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If a server is already running it will stop it and start the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## Do I need to use llama.cpp's server (llama-server)?
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported. For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
## Features:
@@ -29,19 +18,44 @@ Any OpenAI compatible server would work. llama-swap was originally designed for
- `v1/embeddings`
- `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
-Multiple GPU support
- ✅ Docker Support ([#40](https://github.com/mostlygeek/llama-swap/pull/40))
- ✅ Run multiple models at once with `profiles`
-Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Remote log monitoring at `/log`
- ✅ Automatic unloading of models from GPUs after timeout
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
-
- ✅ Manually unload models via `/unload` endpoint ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support
## How does llama-swap work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## config.yaml
llama-swap's configuration is purposefully simple.
```yaml
models:
"qwen2.5":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999
"smollm2":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
```
<details>
<summary>But also very powerful ...</summary>
```yaml
# Seconds to wait for llama.cpp to load and be ready to serve requests
# Default (and minimum) is 15 seconds
@@ -60,8 +74,8 @@ models:
# aliases names to use this model for
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# check this path for an HTTP 200 OK before serving requests
# default: /health to match llama.cpp
@@ -88,17 +102,12 @@ models:
# unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal
"qwen-unlisted":
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
unlisted: true
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker Support (Experimental)
# see: https://github.com/mostlygeek/llama-swap/pull/40
"dockertest":
# Docker Support (v26.1.4+ required!)
"docker-llama":
proxy: "http://127.0.0.1:9790"
# introduced to reliably stop containers
cmd_stop: docker stop -t 2 dockertest
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
@@ -117,17 +126,74 @@ profiles:
- "llama"
```
### Advanced Examples
### Use Case Examples
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
### Installation
## Configuration
llama-s
</details>
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker is the quickest way to try out llama-swap:
```
# use CPU inference
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
# qwen2.5 0.5B
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
# SmolLM2 135M
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
```
<details>
<summary>Docker images are nightly ...</summary>
They include:
- `ghcr.io/mostlygeek/llama-swap:cpu`
- `ghcr.io/mostlygeek/llama-swap:cuda`
- `ghcr.io/mostlygeek/llama-swap:intel`
- `ghcr.io/mostlygeek/llama-swap:vulkan`
- ROCm disabled until fixed in llama.cpp container
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
```
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
```
</details>
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
* _Note: Windows currently untested._
1. Run the binary with `llama-swap --config path/to/config.yaml`
### Building from source
@@ -157,11 +223,18 @@ curl -Ns http://host/logs/stream | grep 'eval time'
curl -Ns 'http://host/logs/stream?no-history'
```
## Do I need to use llama.cpp's server (llama-server)?
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
## Systemd Unit Files
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
`/etc/systemd/system/llama-swap.service`
```
[Unit]
Description=llama-swap
@@ -181,3 +254,11 @@ StartLimitInterval=30
[Install]
WantedBy=multi-user.target
```
## Star History
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
+1 -8
View File
@@ -53,16 +53,9 @@ models:
--ctx-size 8192
--reranking
# EXPERIMENTAL! Docker Support
# see:
# - https://github.com/mostlygeek/llama-swap/pull/40
# - https://github.com/mostlygeek/llama-swap/issues/35
# Docker Support (v26.1.4+ required!)
"dockertest":
proxy: "http://127.0.0.1:9790"
# use this to reliably stop named containers
cmd_stop: docker stop -t 2 dockertest
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
+49
View File
@@ -0,0 +1,49 @@
#!/bin/bash
cd $(dirname "$0")
ARCH=$1
PUSH_IMAGES=${2:-false}
# List of allowed architectures
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
# Check if ARCH is in the allowed list
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
exit 1
fi
# Check if GITHUB_TOKEN is set and not empty
if [[ -z "$GITHUB_TOKEN" ]]; then
echo "Error: GITHUB_TOKEN is not set or is empty."
exit 1
fi
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
if [ "$ARCH" == "cpu" ]; then
# cpu only containers just use the latest available
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
echo "Building ${CONTAINER_LATEST} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_LATEST}
fi
else
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
fi
+17
View File
@@ -0,0 +1,17 @@
healthCheckTimeout: 300
logRequests: true
models:
"qwen2.5":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999
"smollm2":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
+16
View File
@@ -0,0 +1,16 @@
ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=89
WORKDIR /app
RUN \
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
COPY config.example.yaml /app/config.yaml
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
@@ -0,0 +1,51 @@
# Restart llama-swap on config change
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
```bash
#!/bin/bash
#
# A simple watch and restart llama-swap when its configuration
# file changes. Useful for trying out configuration changes
# without manually restarting the server each time.
if [ -z "$1" ]; then
echo "Usage: $0 <path to config.yaml>"
exit 1
fi
while true; do
# Start the process again
./llama-swap-linux-amd64 -config $1 -listen :1867 &
PID=$!
echo "Started llama-swap with PID $PID"
# Wait for modifications in the specified directory or file
inotifywait -e modify "$1"
# Check if process exists before sending signal
if kill -0 $PID 2>/dev/null; then
echo "Sending SIGTERM to $PID"
kill -SIGTERM $PID
wait $PID
else
echo "Process $PID no longer exists"
fi
sleep 1
done
```
## Usage and output example
```bash
$ ./watch-and-restart.sh config.yaml
Started llama-swap with PID 495455
Setting up watches.
Watches established.
llama-swap listening on :1867
Sending SIGTERM to 495455
Shutting down llama-swap
Started llama-swap with PID 495486
Setting up watches.
Watches established.
llama-swap listening on :1867
```
+4
View File
@@ -29,6 +29,10 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
+10
View File
@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
+1 -1
View File
@@ -47,7 +47,7 @@ func main() {
go func() {
<-sigChan
fmt.Println("Shutting down llama-swap")
proxyManager.StopProcesses()
proxyManager.Shutdown()
os.Exit(0)
}()
+21
View File
@@ -12,12 +12,14 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
func main() {
gin.SetMode(gin.TestMode)
// Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on")
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
// Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with")
@@ -41,6 +43,25 @@ func main() {
c.String(200, *responseMessage)
})
// for issue #62 to check model name strips profile slug
// has to be one of the openAI API endpoints that llama-swap proxies
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
r.POST("/v1/audio/speech", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
if modelName != *expectedModel {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
return
} else {
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
-4
View File
@@ -11,7 +11,6 @@ import (
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmd_stop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
@@ -23,9 +22,6 @@ type ModelConfig struct {
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
func (m *ModelConfig) SanitizeCommandStop() ([]string, error) {
return SanitizeCommand(m.CmdStop)
}
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
-26
View File
@@ -35,11 +35,6 @@ models:
aliases:
- "m2"
checkEndpoint: "/"
docker:
cmd: docker run -p 9999:8080 --name "my_container"
cmd_stop: docker stop my_container
proxy: "http://localhost:9999"
checkEndpoint: "/health"
healthCheckTimeout: 15
profiles:
test:
@@ -61,7 +56,6 @@ profiles:
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
@@ -69,19 +63,11 @@ profiles:
},
"model2": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: nil,
CheckEndpoint: "/",
},
"docker": {
Cmd: `docker run -p 9999:8080 --name "my_container"`,
CmdStop: "docker stop my_container",
Proxy: "http://localhost:9999",
Env: nil,
CheckEndpoint: "/health",
},
},
HealthCheckTimeout: 15,
Profiles: map[string][]string{
@@ -113,18 +99,6 @@ func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestConfig_ModelConfigSanitizedCommandStop(t *testing.T) {
config := &ModelConfig{
CmdStop: `docker stop my_container \
--arg1 1
--arg2 2`,
}
args, err := config.SanitizeCommandStop()
assert.NoError(t, err)
assert.Equal(t, []string{"docker", "stop", "my_container", "--arg1", "1", "--arg2", "2"}, args)
}
func TestConfig_FindConfig(t *testing.T) {
// TODO?
+232 -181
View File
@@ -17,14 +17,19 @@ import (
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateReady ProcessState = ProcessState("ready")
StateFailed ProcessState = ProcessState("failed")
StateStopped ProcessState = ProcessState("stopped")
StateStarting ProcessState = ProcessState("starting")
StateReady ProcessState = ProcessState("ready")
StateStopping ProcessState = ProcessState("stopping")
// failed a health check on start and will not be recovered
StateFailed ProcessState = ProcessState("failed")
// process is shutdown and will not be restarted
StateShutdown ProcessState = ProcessState("shutdown")
)
type Process struct {
sync.Mutex
ID string
config ModelConfig
cmd *exec.Cmd
@@ -37,9 +42,17 @@ type Process struct {
state ProcessState
inFlightRequests sync.WaitGroup
// used to block on multiple start() calls
waitStarting sync.WaitGroup
// for managing shutdown state
shutdownCtx context.Context
shutdownCancel context.CancelFunc
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background())
return &Process{
ID: ID,
config: config,
@@ -47,22 +60,88 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
}
}
// start the process and returns when it is ready
func (p *Process) setState(newState ProcessState) error {
// enforce valid state transitions
invalidTransition := false
if p.state == StateStopped {
// stopped -> starting
if newState != StateStarting {
invalidTransition = true
}
} else if p.state == StateStarting {
// starting -> ready | failed | stopping
if newState != StateReady && newState != StateFailed && newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateReady {
// ready -> stopping
if newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateStopping {
// stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
}
if invalidTransition {
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
}
p.state = newState
return nil
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
func (p *Process) start() error {
if p.config.Proxy == "" {
return fmt.Errorf("can not start(), upstream proxy missing")
}
// wait for the other start() to complete
curState := p.CurrentState()
if curState == StateReady {
return nil
}
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state != StateReady {
return fmt.Errorf("start() failed current state: %v", state)
}
return nil
}
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state == StateReady {
return nil
if err := p.setState(StateStarting); err != nil {
return err
}
if p.state == StateFailed {
return fmt.Errorf("process is in a failed state and can not be restarted")
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand()
if err != nil {
@@ -77,7 +156,8 @@ func (p *Process) start() error {
err = p.cmd.Start()
if err != nil {
return err
p.setState(StateFailed)
return fmt.Errorf("start() failed: %v", err)
}
// One of three things can happen at this stage:
@@ -86,35 +166,56 @@ func (p *Process) start() error {
// 3. The health check passes
//
// only in the third case will the process be considered Ready to accept
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
defer cancelHealthCheck(nil) // clean up
cmdWaitChan := make(chan error, 1)
healthCheckChan := make(chan error, 1)
<-time.After(250 * time.Millisecond) // give process a bit of time to start
go func() {
// possible cmd exits early
cmdWaitChan <- p.cmd.Wait()
}()
checkStartTime := time.Now()
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
go func() {
<-time.After(250 * time.Millisecond) // give process a bit of time to start
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
}()
select {
case err := <-cmdWaitChan:
p.state = StateFailed
if err != nil {
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
} else {
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
// a "none" means don't check for health ... I could have picked a better word :facepalm:
if checkEndpoint != "none" {
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
cancelHealthCheck(err)
return err
case err := <-healthCheckChan:
proxyTo := p.config.Proxy
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
p.state = StateFailed
return err
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
}
checkDeadline, cancelHealthCheck := context.WithDeadline(
context.Background(),
checkStartTime.Add(maxDuration),
)
defer cancelHealthCheck()
// Health check loop
loop:
for {
select {
case <-checkDeadline.Done():
p.setState(StateFailed)
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown")
default:
if err := p.checkHealthEndpoint(healthURL); err == nil {
cancelHealthCheck()
break loop
} else {
if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime)
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
} else {
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
}
}
}
<-time.After(5 * time.Second)
}
}
@@ -141,192 +242,142 @@ func (p *Process) start() error {
}()
}
p.state = StateReady
return nil
return p.setState(StateReady)
}
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != StateReady {
fmt.Fprintf(p.logMonitor, "!!! Stop() called but Process State is not READY\n")
// calling Stop() when state is invalid is a no-op
if err := p.setState(StateStopping); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
return
}
// stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second)
if err := p.setState(StateStopped); err != nil {
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
}
}
// Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second)
if err := p.setState(StateShutdown); err != nil {
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
}
p.setState(StateShutdown)
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) {
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil {
// this situation should never happen... but if it does just update the state
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.\n")
p.state = StateStopped
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
return
}
// Pretty sure this stopping code needs some work for windows and
// will be a source of pain in the future.
p.cmd.Process.Signal(syscall.SIGTERM)
if p.config.CmdStop != "" {
// for issue #35 to do things like `docker stop`
args, err := p.config.SanitizeCommandStop()
select {
case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.cmd.Process.Kill()
case err := <-sigtermNormal:
if err != nil {
fmt.Fprintf(p.logMonitor, "!!! Error sanitizing stop command: %v\n", err)
// leave the state as it is?
return
}
fmt.Fprintf(p.logMonitor, "!!! Running stop command: %s\n", strings.Join(args, " "))
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = p.logMonitor
cmd.Stderr = p.logMonitor
err = cmd.Start()
if err != nil {
fmt.Fprintf(p.logMonitor, "!!! Error running stop command: %v\n", err)
// leave the state as it is?
return
}
err = cmd.Wait()
if err != nil {
fmt.Fprintf(p.logMonitor, "!!! WARNING error waiting for stop command to complete: %v\n", err)
}
} else {
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
p.cmd.Process.Signal(syscall.SIGTERM)
select {
case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "XXX Process for %s timed out waiting to stop, sending SIGKILL to PID: %d\n", p.ID, p.cmd.Process.Pid)
p.cmd.Process.Kill()
p.cmd.Wait()
case err := <-sigtermNormal:
if err != nil {
if err.Error() != "wait: no child processes" {
// possible that simple-responder for testing is just not
// existing right, so suppress those errors.
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
if errno, ok := err.(syscall.Errno); ok {
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
} else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
}
}
}
}
p.state = StateStopped
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
if p.config.Proxy == "" {
return fmt.Errorf("no upstream available to check /health")
}
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
if checkEndpoint == "none" {
return nil
}
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := p.config.Proxy
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
}
client := &http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
defer cancel()
req = req.WithContext(ctx)
resp, err := client.Do(req)
ttl := (maxDuration - time.Since(startTime)).Seconds()
if err != nil {
// check if the context was cancelled
select {
case <-ctx.Done():
err := context.Cause(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
return err
}
default:
}
// wait a bit longer for TCP connection issues
if strings.Contains(err.Error(), "connection refused") {
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
time.Sleep(5 * time.Second)
} else {
time.Sleep(time.Second)
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
}
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{
Timeout: 500 * time.Millisecond,
}
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// got a response but it was not an OK
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code: %d", resp.StatusCode)
}
return nil
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Add(1)
// prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState()
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
return
}
p.inFlightRequests.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.inFlightRequests.Done()
}()
// start the process on demand
if p.CurrentState() != StateReady {
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusInternalServerError)
http.Error(w, errstr, http.StatusBadGateway)
return
}
}
proxyTo := p.config.Proxy
client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
+115 -2
View File
@@ -48,6 +48,33 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
}
}
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
// are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop()
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(reqID int) {
defer wg.Done()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
}(i)
}
wg.Wait()
assert.Equal(t, StateReady, process.CurrentState())
}
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
@@ -58,13 +85,17 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
}
process := NewProcess("broken", 1, config, NewLogMonitor())
defer process.Stop()
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "unable to start process")
w = httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
}
func TestProcess_UnloadAfterTTL(t *testing.T) {
@@ -190,3 +221,85 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
assert.Equal(t, key, result)
}
}
func TestSetState(t *testing.T) {
tests := []struct {
name string
currentState ProcessState
newState ProcessState
expectedError error
expectedResult ProcessState
}{
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p := &Process{
state: test.currentState,
}
err := p.setState(test.newState)
if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil {
t.Errorf("Expected error: %v, but got none", test.expectedError)
} else if err != nil && test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
}
}
if p.state != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
}
})
}
}
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping long shutdown test")
}
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
// start a goroutine to simulate a shutdown
var wg sync.WaitGroup
go func() {
defer wg.Done()
<-time.After(time.Second * 2)
process.Shutdown()
}()
wg.Add(1)
// start the process, this is a blocking call
err := process.start()
wg.Wait()
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState())
}
+72 -15
View File
@@ -13,6 +13,8 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -69,6 +71,19 @@ func New(config *Config) *ProxyManager {
})
}
// see: https://github.com/mostlygeek/llama-swap/issues/42
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.AbortWithStatus(204)
return
}
c.Next()
})
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12
@@ -91,6 +106,8 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/", func(c *gin.Context) {
// Set the Content-Type header to text/html
c.Header("Content-Type", "text/html")
@@ -143,13 +160,38 @@ func (pm *ProxyManager) stopProcesses() {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
process.Stop()
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Stop()
}(process)
}
wg.Wait()
pm.currentProcesses = make(map[string]*Process)
}
// Shutdown is called to shutdown all upstream processes
// when llama-swap is shutting down.
func (pm *ProxyManager) Shutdown() {
pm.Lock()
defer pm.Unlock()
// shutdown process in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{}
for id, modelConfig := range pm.config.Models {
@@ -184,11 +226,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
@@ -304,21 +342,26 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return
}
var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
return
}
model, ok := requestBody["model"].(string)
if !ok {
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
if process, err := pm.swapModel(model); err != nil {
if process, err := pm.swapModel(requestedModel); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
} else {
// strip
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
@@ -339,6 +382,20 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
}
}
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses()
c.String(http.StatusOK, "OK")
}
func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName
}
func splitRequestedModel(requestedModel string) (string, string) {
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
return profileName, modelName
}
+95
View File
@@ -254,3 +254,98 @@ func TestProxyManager_ProfileNonMember(t *testing.T) {
assert.Equal(t, http.StatusNotFound, w.Code)
}
}
func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
model1Config.Proxy = "http://localhost:10001/"
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
model2Config.Proxy = "http://localhost:10002/"
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/"
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2", "model3"},
},
Models: map[string]ModelConfig{
"model1": model1Config,
"model2": model2Config,
"model3": model3Config,
},
}
proxy := New(config)
// Start all the processes
var wg sync.WaitGroup
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
wg.Add(1)
go func(modelName string) {
defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
// send a request to trigger the proxy to load
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
//fmt.Println(w.Code, w.Body.String())
}(modelName)
}
go func() {
<-time.After(time.Second)
proxy.Shutdown()
}()
wg.Wait()
}
func TestProxyManager_Unload(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
}
proxy := New(config)
proc, err := proxy.swapModel("model1")
assert.NoError(t, err)
assert.NotNil(t, proc)
assert.Len(t, proxy.currentProcesses, 1)
req := httptest.NewRequest("GET", "/unload", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// issue 62, strip profile slug from model name
func TestProxyManager_StripProfileSlug(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
}