Compare commits
56 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae | |||
| 1e25b44a06 | |||
| 0815bb4cc3 | |||
| 7187cfe52e | |||
| 24089d2d9c | |||
| ebabe55ff3 | |||
| 41a338297c | |||
| 7e3353efeb | |||
| 4ed58fb173 | |||
| f5a2be698d | |||
| f5e6ec3b7a | |||
| 3f462da146 | |||
| 48bd766536 | |||
| 8d319da4dd | |||
| be7c502448 | |||
| 92336f00bf | |||
| ed2a50d9a6 | |||
| 0acfdb9f78 | |||
| 96a8ea0241 | |||
| f20f2c9b7a | |||
| 7a97c38828 | |||
| 4885132565 | |||
| 8b46a0b7f1 | |||
| 1b6736ec6f | |||
| ddc1ce031e | |||
| 11d024bbaa | |||
| 43e23c16dc | |||
| f9c8e763ba | |||
| d7e1bb9f7c | |||
| ab93460a8b | |||
| 13d4552edc | |||
| 6667e307a2 | |||
| 7ac446e6a9 | |||
| eab9795bcc | |||
| 09bdd86b54 | |||
| 85cd74a51c | |||
| 314d2f2212 | |||
| fad25f3e11 | |||
| 2c3e3e27f7 | |||
| baeb0c4e7f | |||
| 2833517eef | |||
| abdc2bfdb3 | |||
| c3b834737f | |||
| 3c8e727b73 | |||
| 3a1e9f81f1 | |||
| 72c883f36c | |||
| 1b04d034cf | |||
| 2e45f5692a | |||
| c97b80bdfe | |||
| ae3ef9bc39 | |||
| db6715bec3 | |||
| da5d9e8a6a | |||
| 84b667ca7a | |||
| 29657106fc |
@@ -0,0 +1,45 @@
|
|||||||
|
name: Build Containers
|
||||||
|
|
||||||
|
on:
|
||||||
|
# time has no specific meaning, trying to time it after
|
||||||
|
# the llama.cpp daily packages are published
|
||||||
|
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||||
|
schedule:
|
||||||
|
- cron: "37 5 * * *"
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-and-push:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
platform: [intel, cuda, vulkan, cpu, musa]
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Log in to GitHub Container Registry
|
||||||
|
uses: docker/login-action@v2
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Run build-container
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||||
|
|
||||||
|
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||||
|
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||||
|
delete-untagged-containers:
|
||||||
|
needs: build-and-push
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/delete-package-versions@v5
|
||||||
|
with:
|
||||||
|
package-name: 'llama-swap'
|
||||||
|
package-type: 'container'
|
||||||
|
delete-only-untagged-versions: 'true'
|
||||||
@@ -5,6 +5,9 @@ on:
|
|||||||
tags:
|
tags:
|
||||||
- '*'
|
- '*'
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,13 @@ builds:
|
|||||||
goos:
|
goos:
|
||||||
- linux
|
- linux
|
||||||
- darwin
|
- darwin
|
||||||
|
- freebsd
|
||||||
|
- windows
|
||||||
goarch:
|
goarch:
|
||||||
- amd64
|
- amd64
|
||||||
- arm64
|
- arm64
|
||||||
|
ignore:
|
||||||
|
- goos: freebsd
|
||||||
|
goarch: arm64
|
||||||
|
- goos: windows
|
||||||
|
goarch: arm64
|
||||||
@@ -35,6 +35,11 @@ linux:
|
|||||||
@echo "Building Linux binary..."
|
@echo "Building Linux binary..."
|
||||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||||
|
|
||||||
|
# Build Windows binary
|
||||||
|
windows:
|
||||||
|
@echo "Building Windows binary..."
|
||||||
|
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||||
|
|
||||||
# for testing proxy.Process
|
# for testing proxy.Process
|
||||||
simple-responder:
|
simple-responder:
|
||||||
@echo "Building simple responder"
|
@echo "Building simple responder"
|
||||||
@@ -60,4 +65,4 @@ release:
|
|||||||
git tag "$$new_tag";
|
git tag "$$new_tag";
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean osx linux
|
.PHONY: all clean mac linux windows simple-responder
|
||||||
|
|||||||
@@ -1,44 +1,69 @@
|
|||||||
# llama-swap
|
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
# Introduction
|
# llama-swap
|
||||||
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
|
||||||
|
|
||||||
Features:
|
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||||
|
|
||||||
|
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||||
|
|
||||||
|
## Features:
|
||||||
|
|
||||||
- ✅ Easy to deploy: single binary with no dependencies
|
- ✅ Easy to deploy: single binary with no dependencies
|
||||||
- ✅ Easy to config: single yaml file
|
- ✅ Easy to config: single yaml file
|
||||||
- ✅ On-demand model switching
|
- ✅ On-demand model switching
|
||||||
- ✅ Full control over server settings per model
|
- ✅ Full control over server settings per model
|
||||||
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
- ✅ OpenAI API supported endpoints:
|
||||||
|
- `v1/completions`
|
||||||
|
- `v1/chat/completions`
|
||||||
|
- `v1/embeddings`
|
||||||
|
- `v1/rerank`
|
||||||
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- ✅ Multiple GPU support
|
- ✅ Multiple GPU support
|
||||||
- ✅ Run multiple models at once with `profiles`
|
- ✅ Docker and Podman support
|
||||||
|
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||||
- ✅ Remote log monitoring at `/log`
|
- ✅ Remote log monitoring at `/log`
|
||||||
- ✅ Automatic unloading of models from GPUs after timeout
|
- ✅ Automatic unloading of models from GPUs after timeout
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabblyAPI, etc)
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
|
|
||||||
## Releases
|
## How does llama-swap work?
|
||||||
|
|
||||||
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
|
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||||
|
|
||||||
### Building from source
|
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||||
|
|
||||||
1. Install golang for your system
|
|
||||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
|
||||||
1. `make clean all`
|
|
||||||
1. Binaries will be in `build/` subdirectory
|
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
llama-swap's configuration is purposefully simple.
|
llama-swap's configuration is purposefully simple.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
"qwen2.5":
|
||||||
|
proxy: "http://127.0.0.1:9999"
|
||||||
|
cmd: >
|
||||||
|
/app/llama-server
|
||||||
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
|
--port 9999
|
||||||
|
|
||||||
|
"smollm2":
|
||||||
|
proxy: "http://127.0.0.1:9999"
|
||||||
|
cmd: >
|
||||||
|
/app/llama-server
|
||||||
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
|
--port 9999
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>But also very powerful ...</summary>
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||||
# Default (and minimum) is 15 seconds
|
# Default (and minimum) is 15 seconds
|
||||||
healthCheckTimeout: 60
|
healthCheckTimeout: 60
|
||||||
|
|
||||||
|
# Write HTTP logs (useful for troubleshooting), defaults to false
|
||||||
|
logRequests: true
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
@@ -77,8 +102,17 @@ models:
|
|||||||
# unlisted models do not show up in /v1/models or /upstream lists
|
# unlisted models do not show up in /v1/models or /upstream lists
|
||||||
# but they can still be requested as normal
|
# but they can still be requested as normal
|
||||||
"qwen-unlisted":
|
"qwen-unlisted":
|
||||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
|
||||||
unlisted: true
|
unlisted: true
|
||||||
|
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
|
# Docker Support (v26.1.4+ required!)
|
||||||
|
"docker-llama":
|
||||||
|
proxy: "http://127.0.0.1:9790"
|
||||||
|
cmd: >
|
||||||
|
docker run --name dockertest
|
||||||
|
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||||
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||||
#
|
#
|
||||||
@@ -92,19 +126,78 @@ profiles:
|
|||||||
- "llama"
|
- "llama"
|
||||||
```
|
```
|
||||||
|
|
||||||
**Guides and examples**
|
### Advanced Examples
|
||||||
|
|
||||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||||
|
|
||||||
## Installation
|
</details>
|
||||||
|
|
||||||
|
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||||
|
|
||||||
|
Docker is the quickest way to try out llama-swap:
|
||||||
|
|
||||||
|
```
|
||||||
|
# use CPU inference
|
||||||
|
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||||
|
|
||||||
|
|
||||||
|
# qwen2.5 0.5B
|
||||||
|
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer no-key" \
|
||||||
|
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||||
|
jq -r '.choices[0].message.content'
|
||||||
|
|
||||||
|
|
||||||
|
# SmolLM2 135M
|
||||||
|
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer no-key" \
|
||||||
|
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||||
|
jq -r '.choices[0].message.content'
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Docker images are nightly ...</summary>
|
||||||
|
|
||||||
|
They include:
|
||||||
|
|
||||||
|
- `ghcr.io/mostlygeek/llama-swap:cpu`
|
||||||
|
- `ghcr.io/mostlygeek/llama-swap:cuda`
|
||||||
|
- `ghcr.io/mostlygeek/llama-swap:intel`
|
||||||
|
- `ghcr.io/mostlygeek/llama-swap:vulkan`
|
||||||
|
- ROCm disabled until fixed in llama.cpp container
|
||||||
|
|
||||||
|
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
|
||||||
|
|
||||||
|
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||||
|
-v /path/to/models:/models \
|
||||||
|
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||||
|
ghcr.io/mostlygeek/llama-swap:cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||||
|
|
||||||
|
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||||
|
|
||||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||||
* _Note: Windows currently untested._
|
|
||||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||||
|
|
||||||
|
### Building from source
|
||||||
|
|
||||||
|
1. Install golang for your system
|
||||||
|
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||||
|
1. `make clean all`
|
||||||
|
1. Binaries will be in `build/` subdirectory
|
||||||
|
|
||||||
## Monitoring Logs
|
## Monitoring Logs
|
||||||
|
|
||||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||||
@@ -125,11 +218,18 @@ curl -Ns http://host/logs/stream | grep 'eval time'
|
|||||||
curl -Ns 'http://host/logs/stream?no-history'
|
curl -Ns 'http://host/logs/stream?no-history'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Do I need to use llama.cpp's server (llama-server)?
|
||||||
|
|
||||||
|
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||||
|
|
||||||
|
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||||
|
|
||||||
## Systemd Unit Files
|
## Systemd Unit Files
|
||||||
|
|
||||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||||
|
|
||||||
`/etc/systemd/system/llama-swap.service`
|
`/etc/systemd/system/llama-swap.service`
|
||||||
|
|
||||||
```
|
```
|
||||||
[Unit]
|
[Unit]
|
||||||
Description=llama-swap
|
Description=llama-swap
|
||||||
@@ -149,3 +249,11 @@ StartLimitInterval=30
|
|||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||||
|
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||||
|
</picture>
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
# Default (and minimum): 15 seconds
|
# Default (and minimum): 15 seconds
|
||||||
healthCheckTimeout: 15
|
healthCheckTimeout: 15
|
||||||
|
|
||||||
|
# Log HTTP requests helpful for troubleshoot, defaults to False
|
||||||
|
logRequests: true
|
||||||
|
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
cmd: >
|
cmd: >
|
||||||
@@ -50,6 +53,14 @@ models:
|
|||||||
--ctx-size 8192
|
--ctx-size 8192
|
||||||
--reranking
|
--reranking
|
||||||
|
|
||||||
|
# Docker Support (v26.1.4+ required!)
|
||||||
|
"dockertest":
|
||||||
|
proxy: "http://127.0.0.1:9790"
|
||||||
|
cmd: >
|
||||||
|
docker run --name dockertest
|
||||||
|
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||||
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
"simple":
|
"simple":
|
||||||
# example of setting environment variables
|
# example of setting environment variables
|
||||||
|
|||||||
Executable
+49
@@ -0,0 +1,49 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
cd $(dirname "$0")
|
||||||
|
|
||||||
|
ARCH=$1
|
||||||
|
PUSH_IMAGES=${2:-false}
|
||||||
|
|
||||||
|
# List of allowed architectures
|
||||||
|
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||||
|
|
||||||
|
# Check if ARCH is in the allowed list
|
||||||
|
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||||
|
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if GITHUB_TOKEN is set and not empty
|
||||||
|
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||||
|
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# the most recent llama-swap tag
|
||||||
|
# have to strip out the 'v' due to .tar.gz file naming
|
||||||
|
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||||
|
|
||||||
|
if [ "$ARCH" == "cpu" ]; then
|
||||||
|
# cpu only containers just use the latest available
|
||||||
|
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
||||||
|
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
||||||
|
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
||||||
|
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||||
|
docker push ${CONTAINER_LATEST}
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||||
|
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||||
|
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||||
|
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||||
|
|
||||||
|
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||||
|
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||||
|
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||||
|
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||||
|
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||||
|
docker push ${CONTAINER_TAG}
|
||||||
|
docker push ${CONTAINER_LATEST}
|
||||||
|
fi
|
||||||
|
fi
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
healthCheckTimeout: 300
|
||||||
|
logRequests: true
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen2.5":
|
||||||
|
proxy: "http://127.0.0.1:9999"
|
||||||
|
cmd: >
|
||||||
|
/app/llama-server
|
||||||
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
|
--port 9999
|
||||||
|
|
||||||
|
"smollm2":
|
||||||
|
proxy: "http://127.0.0.1:9999"
|
||||||
|
cmd: >
|
||||||
|
/app/llama-server
|
||||||
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
|
--port 9999
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
ARG BASE_TAG=server-cuda
|
||||||
|
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||||
|
|
||||||
|
# has to be after the FROM
|
||||||
|
ARG LS_VER=89
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
RUN \
|
||||||
|
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||||
|
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||||
|
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||||
|
|
||||||
|
COPY config.example.yaml /app/config.yaml
|
||||||
|
|
||||||
|
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||||
|
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||||
@@ -33,7 +33,7 @@ require (
|
|||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.31.0 // indirect
|
golang.org/x/crypto v0.31.0 // indirect
|
||||||
golang.org/x/net v0.25.0 // indirect
|
golang.org/x/net v0.33.0 // indirect
|
||||||
golang.org/x/sys v0.28.0 // indirect
|
golang.org/x/sys v0.28.0 // indirect
|
||||||
golang.org/x/text v0.21.0 // indirect
|
golang.org/x/text v0.21.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.1 // indirect
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
|
|||||||
@@ -70,6 +70,8 @@ golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
|||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||||
|
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||||
|
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
@@ -39,6 +41,16 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
proxyManager := proxy.New(config)
|
proxyManager := proxy.New(config)
|
||||||
|
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
<-sigChan
|
||||||
|
fmt.Println("Shutting down llama-swap")
|
||||||
|
proxyManager.Shutdown()
|
||||||
|
os.Exit(0)
|
||||||
|
}()
|
||||||
|
|
||||||
fmt.Println("llama-swap listening on " + *listenStr)
|
fmt.Println("llama-swap listening on " + *listenStr)
|
||||||
if err := proxyManager.Run(*listenStr); err != nil {
|
if err := proxyManager.Run(*listenStr); err != nil {
|
||||||
fmt.Printf("Server error: %v\n", err)
|
fmt.Printf("Server error: %v\n", err)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
|
LogRequests bool `yaml:"logRequests"`
|
||||||
Models map[string]ModelConfig `yaml:"models"`
|
Models map[string]ModelConfig `yaml:"models"`
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>llama-swap</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>llama-swap</h1>
|
||||||
|
<p>
|
||||||
|
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
||||||
|
</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
+107
-15
@@ -12,42 +12,134 @@
|
|||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
font-family: "Courier New", Courier, monospace;
|
font-family: "Courier New", Courier, monospace;
|
||||||
}
|
}
|
||||||
|
#log-controls {
|
||||||
|
margin: 0.5em;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between; /* Spaces out elements evenly */
|
||||||
|
}
|
||||||
|
#log-controls input {
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
#log-controls input:focus {
|
||||||
|
outline: none; /* Ensures no outline is shown when the input is focused */
|
||||||
|
}
|
||||||
#log-stream {
|
#log-stream {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
margin: 1em;
|
margin: 0.5em;
|
||||||
padding: 10px;
|
padding: 1em;
|
||||||
background: #f4f4f4;
|
background: #f4f4f4;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
white-space: pre-wrap; /* Ensures line wrapping */
|
white-space: pre-wrap; /* Ensures line wrapping */
|
||||||
word-wrap: break-word; /* Ensures long words wrap */
|
word-wrap: break-word; /* Ensures long words wrap */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.regex-error {
|
||||||
|
background-color: #ff0000 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Dark mode styles */
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
body {
|
||||||
|
background-color: #333;
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
|
||||||
|
#log-stream {
|
||||||
|
background: #444;
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
|
||||||
|
#log-controls input {
|
||||||
|
background: #555;
|
||||||
|
color: #fff;
|
||||||
|
border: 1px solid #777;
|
||||||
|
}
|
||||||
|
|
||||||
|
#log-controls button {
|
||||||
|
background: #555;
|
||||||
|
color: #fff;
|
||||||
|
border: 1px solid #777;
|
||||||
|
}
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<pre id="log-stream">Waiting for logs...
|
<pre id="log-stream">Waiting for logs...</pre>
|
||||||
</pre>
|
<div id="log-controls">
|
||||||
|
<input type="text" id="filter-input" placeholder="regex filter">
|
||||||
|
<button id="clear-button">clear</button>
|
||||||
|
</div>
|
||||||
<script>
|
<script>
|
||||||
// Establish an EventSource connection to the SSE endpoint
|
const logStream = document.getElementById('log-stream');
|
||||||
|
const filterInput = document.getElementById('filter-input');
|
||||||
|
var logData = "";
|
||||||
|
let regexFilter = null;
|
||||||
|
|
||||||
|
function setupEventSource() {
|
||||||
if (typeof(EventSource) !== "undefined") {
|
if (typeof(EventSource) !== "undefined") {
|
||||||
const eventSource = new EventSource("/logs/streamSSE");
|
const eventSource = new EventSource("/logs/streamSSE");
|
||||||
|
|
||||||
eventSource.onmessage = function(event) {
|
eventSource.onmessage = function(event) {
|
||||||
// Append the new log message to the <pre> element
|
logData += event.data;
|
||||||
const logStream = document.getElementById('log-stream');
|
render()
|
||||||
|
|
||||||
logStream.textContent += event.data;
|
|
||||||
|
|
||||||
// Auto-scroll to the bottom
|
|
||||||
logStream.scrollTop = logStream.scrollHeight;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
eventSource.onerror = function(err) {
|
eventSource.onerror = function(err) {
|
||||||
console.error("EventSource failed:", err);
|
logData = "EventSource failed: " + err.message;
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
console.error("SSE not supported by this browser.");
|
logData = "SSE Not supported by this browser."
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// poor-ai's react ¯\_(ツ)_/¯
|
||||||
|
function render() {
|
||||||
|
if (regexFilter) {
|
||||||
|
const lines = logData.split('\n');
|
||||||
|
const filteredLines = lines.filter(line => {
|
||||||
|
return regexFilter === null || regexFilter.test(line);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (filteredLines.length > 0) {
|
||||||
|
logStream.textContent = filteredLines.join('\n') + '\n';
|
||||||
|
} else {
|
||||||
|
logStream.textContent = "";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logStream.textContent = logData;
|
||||||
|
}
|
||||||
|
|
||||||
|
logStream.scrollTop = logStream.scrollHeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateFilter() {
|
||||||
|
const pattern = filterInput.value.trim();
|
||||||
|
filterInput.classList.remove('regex-error');
|
||||||
|
if (pattern) {
|
||||||
|
try {
|
||||||
|
regexFilter = new RegExp(pattern);
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Invalid regex pattern:", e);
|
||||||
|
regexFilter = null;
|
||||||
|
filterInput.classList.add('regex-error');
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
regexFilter = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
render();
|
||||||
|
}
|
||||||
|
|
||||||
|
filterInput.addEventListener('input', updateFilter);
|
||||||
|
document.getElementById('clear-button').addEventListener('click', () => {
|
||||||
|
filterInput.value = "";
|
||||||
|
regexFilter = null;
|
||||||
|
render();
|
||||||
|
});
|
||||||
|
setupEventSource();
|
||||||
|
updateFilter();
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "embed"
|
||||||
|
|
||||||
|
//go:embed html
|
||||||
|
var htmlFiles embed.FS
|
||||||
|
|
||||||
|
func getHTMLFile(path string) ([]byte, error) {
|
||||||
|
return htmlFiles.ReadFile("html/" + path)
|
||||||
|
}
|
||||||
+209
-126
@@ -18,13 +18,18 @@ type ProcessState string
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
StateStopped ProcessState = ProcessState("stopped")
|
StateStopped ProcessState = ProcessState("stopped")
|
||||||
|
StateStarting ProcessState = ProcessState("starting")
|
||||||
StateReady ProcessState = ProcessState("ready")
|
StateReady ProcessState = ProcessState("ready")
|
||||||
|
StateStopping ProcessState = ProcessState("stopping")
|
||||||
|
|
||||||
|
// failed a health check on start and will not be recovered
|
||||||
StateFailed ProcessState = ProcessState("failed")
|
StateFailed ProcessState = ProcessState("failed")
|
||||||
|
|
||||||
|
// process is shutdown and will not be restarted
|
||||||
|
StateShutdown ProcessState = ProcessState("shutdown")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
sync.Mutex
|
|
||||||
|
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
@@ -37,9 +42,17 @@ type Process struct {
|
|||||||
state ProcessState
|
state ProcessState
|
||||||
|
|
||||||
inFlightRequests sync.WaitGroup
|
inFlightRequests sync.WaitGroup
|
||||||
|
|
||||||
|
// used to block on multiple start() calls
|
||||||
|
waitStarting sync.WaitGroup
|
||||||
|
|
||||||
|
// for managing shutdown state
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
@@ -47,22 +60,88 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
|
|||||||
logMonitor: logMonitor,
|
logMonitor: logMonitor,
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
state: StateStopped,
|
state: StateStopped,
|
||||||
|
shutdownCtx: ctx,
|
||||||
|
shutdownCancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// start the process and returns when it is ready
|
func (p *Process) setState(newState ProcessState) error {
|
||||||
|
// enforce valid state transitions
|
||||||
|
invalidTransition := false
|
||||||
|
if p.state == StateStopped {
|
||||||
|
// stopped -> starting
|
||||||
|
if newState != StateStarting {
|
||||||
|
invalidTransition = true
|
||||||
|
}
|
||||||
|
} else if p.state == StateStarting {
|
||||||
|
// starting -> ready | failed | stopping
|
||||||
|
if newState != StateReady && newState != StateFailed && newState != StateStopping {
|
||||||
|
invalidTransition = true
|
||||||
|
}
|
||||||
|
} else if p.state == StateReady {
|
||||||
|
// ready -> stopping
|
||||||
|
if newState != StateStopping {
|
||||||
|
invalidTransition = true
|
||||||
|
}
|
||||||
|
} else if p.state == StateStopping {
|
||||||
|
// stopping -> stopped | shutdown
|
||||||
|
if newState != StateStopped && newState != StateShutdown {
|
||||||
|
invalidTransition = true
|
||||||
|
}
|
||||||
|
} else if p.state == StateFailed || p.state == StateShutdown {
|
||||||
|
invalidTransition = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if invalidTransition {
|
||||||
|
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
|
||||||
|
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.state = newState
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Process) CurrentState() ProcessState {
|
||||||
|
p.stateMutex.RLock()
|
||||||
|
defer p.stateMutex.RUnlock()
|
||||||
|
return p.state
|
||||||
|
}
|
||||||
|
|
||||||
|
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||||
|
// it is a private method because starting is automatic but stopping can be called
|
||||||
|
// at any time.
|
||||||
func (p *Process) start() error {
|
func (p *Process) start() error {
|
||||||
|
|
||||||
|
if p.config.Proxy == "" {
|
||||||
|
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for the other start() to complete
|
||||||
|
curState := p.CurrentState()
|
||||||
|
|
||||||
|
if curState == StateReady {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if curState == StateStarting {
|
||||||
|
p.waitStarting.Wait()
|
||||||
|
|
||||||
|
if state := p.CurrentState(); state != StateReady {
|
||||||
|
return fmt.Errorf("start() failed current state: %v", state)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
p.stateMutex.Lock()
|
p.stateMutex.Lock()
|
||||||
defer p.stateMutex.Unlock()
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
if p.state == StateReady {
|
if err := p.setState(StateStarting); err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.state == StateFailed {
|
p.waitStarting.Add(1)
|
||||||
return fmt.Errorf("process is in a failed state and can not be restarted")
|
defer p.waitStarting.Done()
|
||||||
}
|
|
||||||
|
|
||||||
args, err := p.config.SanitizedCommand()
|
args, err := p.config.SanitizedCommand()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -77,7 +156,8 @@ func (p *Process) start() error {
|
|||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
p.setState(StateFailed)
|
||||||
|
return fmt.Errorf("start() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// One of three things can happen at this stage:
|
// One of three things can happen at this stage:
|
||||||
@@ -86,35 +166,56 @@ func (p *Process) start() error {
|
|||||||
// 3. The health check passes
|
// 3. The health check passes
|
||||||
//
|
//
|
||||||
// only in the third case will the process be considered Ready to accept
|
// only in the third case will the process be considered Ready to accept
|
||||||
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
|
|
||||||
defer cancelHealthCheck(nil) // clean up
|
|
||||||
cmdWaitChan := make(chan error, 1)
|
|
||||||
healthCheckChan := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// possible cmd exits early
|
|
||||||
cmdWaitChan <- p.cmd.Wait()
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||||
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
checkStartTime := time.Now()
|
||||||
case err := <-cmdWaitChan:
|
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||||
p.state = StateFailed
|
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
|
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||||
} else {
|
if checkEndpoint != "none" {
|
||||||
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
|
// keep default behaviour
|
||||||
|
if checkEndpoint == "" {
|
||||||
|
checkEndpoint = "/health"
|
||||||
}
|
}
|
||||||
cancelHealthCheck(err)
|
|
||||||
return err
|
proxyTo := p.config.Proxy
|
||||||
case err := <-healthCheckChan:
|
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.state = StateFailed
|
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||||
return err
|
}
|
||||||
|
|
||||||
|
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
||||||
|
context.Background(),
|
||||||
|
checkStartTime.Add(maxDuration),
|
||||||
|
)
|
||||||
|
defer cancelHealthCheck()
|
||||||
|
|
||||||
|
// Health check loop
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-checkDeadline.Done():
|
||||||
|
p.setState(StateFailed)
|
||||||
|
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
|
||||||
|
case <-p.shutdownCtx.Done():
|
||||||
|
return errors.New("health check interrupted due to shutdown")
|
||||||
|
default:
|
||||||
|
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||||
|
cancelHealthCheck()
|
||||||
|
break loop
|
||||||
|
} else {
|
||||||
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
|
endTime, _ := checkDeadline.Deadline()
|
||||||
|
ttl := time.Until(endTime)
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
<-time.After(5 * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,166 +236,148 @@ func (p *Process) start() error {
|
|||||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
||||||
p.Stop()
|
p.Stop()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
p.state = StateReady
|
return p.setState(StateReady)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
// wait for any inflight requests before proceeding
|
// wait for any inflight requests before proceeding
|
||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
|
||||||
p.stateMutex.Lock()
|
p.stateMutex.Lock()
|
||||||
defer p.stateMutex.Unlock()
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
if p.state != StateReady {
|
// calling Stop() when state is invalid is a no-op
|
||||||
|
if err := p.setState(StateStopping); err != nil {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.cmd == nil || p.cmd.Process == nil {
|
// stop the process with a graceful exit timeout
|
||||||
// this situation should never happen... but if it does just update the state
|
p.stopCommand(5 * time.Second)
|
||||||
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
|
|
||||||
p.state = StateStopped
|
if err := p.setState(StateStopped); err != nil {
|
||||||
return
|
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Pretty sure this stopping code needs some work for windows and
|
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||||
// will be a source of pain in the future.
|
// of time for any inflight requests to complete before shutting down. If the Process
|
||||||
|
// is in the state of starting, it will cancel it and shut it down
|
||||||
|
func (p *Process) Shutdown() {
|
||||||
|
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
|
||||||
|
p.shutdownCancel()
|
||||||
|
|
||||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
p.stateMutex.Lock()
|
||||||
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
defer p.stateMutex.Unlock()
|
||||||
defer cancel()
|
p.setState(StateStopping)
|
||||||
|
|
||||||
|
// 5 seconds to stop the process
|
||||||
|
p.stopCommand(5 * time.Second)
|
||||||
|
if err := p.setState(StateShutdown); err != nil {
|
||||||
|
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
|
||||||
|
}
|
||||||
|
p.setState(StateShutdown)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||||
|
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||||
|
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||||
|
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
||||||
|
defer cancelTimeout()
|
||||||
|
|
||||||
sigtermNormal := make(chan error, 1)
|
sigtermNormal := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
sigtermNormal <- p.cmd.Wait()
|
sigtermNormal <- p.cmd.Wait()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process for %s timed out waiting to stop\n", p.ID)
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
|
||||||
p.cmd.Process.Kill()
|
p.cmd.Process.Kill()
|
||||||
p.cmd.Wait()
|
|
||||||
case err := <-sigtermNormal:
|
case err := <-sigtermNormal:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() != "wait: no child processes" {
|
if errno, ok := err.(syscall.Errno); ok {
|
||||||
// possible that simple-responder for testing is just not
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
|
||||||
// existing right, so suppress those errors.
|
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
|
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
|
||||||
|
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.state = StateStopped
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) CurrentState() ProcessState {
|
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||||
p.stateMutex.RLock()
|
|
||||||
defer p.stateMutex.RUnlock()
|
|
||||||
return p.state
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
client := &http.Client{
|
||||||
if p.config.Proxy == "" {
|
Timeout: 500 * time.Millisecond,
|
||||||
return fmt.Errorf("no upstream available to check /health")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
|
||||||
|
|
||||||
if checkEndpoint == "none" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// keep default behaviour
|
|
||||||
if checkEndpoint == "" {
|
|
||||||
checkEndpoint = "/health"
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
|
||||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
|
||||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
for {
|
|
||||||
req, err := http.NewRequest("GET", healthURL, nil)
|
req, err := http.NewRequest("GET", healthURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
|
|
||||||
defer cancel()
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
|
|
||||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// check if the context was cancelled
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
err := context.Cause(ctx)
|
|
||||||
if !errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait a bit longer for TCP connection issues
|
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
|
||||||
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
} else {
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ttl < 0 {
|
|
||||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
|
// got a response but it was not an OK
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
if ttl < 0 {
|
|
||||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
p.inFlightRequests.Add(1)
|
// prevent new requests from being made while stopping or irrecoverable
|
||||||
|
currentState := p.CurrentState()
|
||||||
|
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
||||||
|
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.inFlightRequests.Add(1)
|
||||||
defer func() {
|
defer func() {
|
||||||
p.lastRequestHandled = time.Now()
|
p.lastRequestHandled = time.Now()
|
||||||
p.inFlightRequests.Done()
|
p.inFlightRequests.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// start the process on demand
|
||||||
if p.CurrentState() != StateReady {
|
if p.CurrentState() != StateReady {
|
||||||
if err := p.start(); err != nil {
|
if err := p.start(); err != nil {
|
||||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||||
http.Error(w, errstr, http.StatusInternalServerError)
|
http.Error(w, errstr, http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
proxyTo := p.config.Proxy
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|||||||
+143
-4
@@ -48,6 +48,33 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||||
|
// are all handled successfully, even though they all may ask for the process to .start()
|
||||||
|
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||||
|
|
||||||
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
expectedMessage := "testing91931"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
process := NewProcess("test-process", 5, config, logMonitor)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(reqID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||||
|
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
}
|
||||||
|
|
||||||
// test that the automatic start returns the expected error type
|
// test that the automatic start returns the expected error type
|
||||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||||
// Create a process configuration
|
// Create a process configuration
|
||||||
@@ -58,16 +85,19 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||||
defer process.Stop()
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// test that the process unloads after the TTL
|
|
||||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping long auto unload TTL test")
|
t.Skip("skipping long auto unload TTL test")
|
||||||
@@ -79,7 +109,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
config.UnloadAfter = 3 // seconds
|
config.UnloadAfter = 3 // seconds
|
||||||
assert.Equal(t, 3, config.UnloadAfter)
|
assert.Equal(t, 3, config.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
// this should take 4 seconds
|
// this should take 4 seconds
|
||||||
@@ -111,6 +141,33 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
assert.Equal(t, StateStopped, process.CurrentState())
|
assert.Equal(t, StateStopped, process.CurrentState())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_LowTTLValue(t *testing.T) {
|
||||||
|
if true { // change this code to run this ...
|
||||||
|
t.Skip("skipping test, edit process_test.go to run it ")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := getTestSimpleResponderConfig("fast_ttl")
|
||||||
|
assert.Equal(t, 0, config.UnloadAfter)
|
||||||
|
config.UnloadAfter = 1 // second
|
||||||
|
assert.Equal(t, 1, config.UnloadAfter)
|
||||||
|
|
||||||
|
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
t.Logf("Waiting before sending request %d", i)
|
||||||
|
time.Sleep(1500 * time.Millisecond)
|
||||||
|
|
||||||
|
expected := fmt.Sprintf("echo=test_%d", i)
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// issue #19
|
// issue #19
|
||||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
@@ -164,3 +221,85 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
assert.Equal(t, key, result)
|
assert.Equal(t, key, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetState(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
currentState ProcessState
|
||||||
|
newState ProcessState
|
||||||
|
expectedError error
|
||||||
|
expectedResult ProcessState
|
||||||
|
}{
|
||||||
|
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
|
||||||
|
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
|
||||||
|
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
|
||||||
|
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
|
||||||
|
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
|
||||||
|
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
|
||||||
|
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
|
||||||
|
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
|
||||||
|
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
|
||||||
|
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
|
||||||
|
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
|
||||||
|
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
|
||||||
|
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
|
||||||
|
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
|
||||||
|
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
|
||||||
|
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
p := &Process{
|
||||||
|
state: test.currentState,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := p.setState(test.newState)
|
||||||
|
if err != nil && test.expectedError == nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
} else if err == nil && test.expectedError != nil {
|
||||||
|
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||||
|
} else if err != nil && test.expectedError != nil {
|
||||||
|
if err.Error() != test.expectedError.Error() {
|
||||||
|
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.state != test.expectedResult {
|
||||||
|
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping long shutdown test")
|
||||||
|
}
|
||||||
|
|
||||||
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
expectedMessage := "testing91931"
|
||||||
|
|
||||||
|
// make a config where the healthcheck will always fail because port is wrong
|
||||||
|
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||||
|
config.Proxy = "http://localhost:9998/test"
|
||||||
|
|
||||||
|
healthCheckTTLSeconds := 30
|
||||||
|
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
||||||
|
|
||||||
|
// start a goroutine to simulate a shutdown
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-time.After(time.Second * 2)
|
||||||
|
process.Shutdown()
|
||||||
|
}()
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
// start the process, this is a blocking call
|
||||||
|
err := process.start()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||||
|
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||||
|
}
|
||||||
|
|||||||
+136
-18
@@ -2,7 +2,6 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"embed"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -20,15 +19,6 @@ const (
|
|||||||
PROFILE_SPLIT_CHAR = ":"
|
PROFILE_SPLIT_CHAR = ":"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed html/favicon.ico
|
|
||||||
var faviconData []byte
|
|
||||||
|
|
||||||
//go:embed html/logs.html
|
|
||||||
var logsHTML []byte
|
|
||||||
|
|
||||||
// make sure embed is kept there by the IDE auto-package importer
|
|
||||||
var _ = embed.FS{}
|
|
||||||
|
|
||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
@@ -46,6 +36,52 @@ func New(config *Config) *ProxyManager {
|
|||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.LogRequests {
|
||||||
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
|
// Start timer
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// capture these because /upstream/:model rewrites them in c.Next()
|
||||||
|
clientIP := c.ClientIP()
|
||||||
|
method := c.Request.Method
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
|
// Process request
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
// Stop timer
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
statusCode := c.Writer.Status()
|
||||||
|
bodySize := c.Writer.Size()
|
||||||
|
|
||||||
|
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
|
||||||
|
clientIP,
|
||||||
|
time.Now().Format("2006-01-02 15:04:05"),
|
||||||
|
method,
|
||||||
|
path,
|
||||||
|
c.Request.Proto,
|
||||||
|
statusCode,
|
||||||
|
bodySize,
|
||||||
|
c.Request.UserAgent(),
|
||||||
|
duration,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
||||||
|
// respond with permissive OPTIONS for any endpoint
|
||||||
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
|
if c.Request.Method == "OPTIONS" {
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||||
|
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
|
c.AbortWithStatus(204)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
@@ -55,6 +91,9 @@ func New(config *Config) *ProxyManager {
|
|||||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// Support audio/speech endpoint
|
||||||
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||||
|
|
||||||
// in proxymanager_loghandlers.go
|
// in proxymanager_loghandlers.go
|
||||||
@@ -65,8 +104,31 @@ func New(config *Config) *ProxyManager {
|
|||||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||||
|
// Set the Content-Type header to text/html
|
||||||
|
c.Header("Content-Type", "text/html")
|
||||||
|
|
||||||
|
// Write the embedded HTML content to the response
|
||||||
|
htmlData, err := getHTMLFile("index.html")
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = c.Writer.Write(htmlData)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||||
c.Data(http.StatusOK, "image/x-icon", faviconData)
|
if data, err := getHTMLFile("favicon.ico"); err == nil {
|
||||||
|
c.Data(http.StatusOK, "image/x-icon", data)
|
||||||
|
} else {
|
||||||
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Disable console color for testing
|
// Disable console color for testing
|
||||||
@@ -96,13 +158,38 @@ func (pm *ProxyManager) stopProcesses() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stop Processes in parallel
|
||||||
|
var wg sync.WaitGroup
|
||||||
for _, process := range pm.currentProcesses {
|
for _, process := range pm.currentProcesses {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(process *Process) {
|
||||||
|
defer wg.Done()
|
||||||
process.Stop()
|
process.Stop()
|
||||||
|
}(process)
|
||||||
}
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
pm.currentProcesses = make(map[string]*Process)
|
pm.currentProcesses = make(map[string]*Process)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shutdown is called to shutdown all upstream processes
|
||||||
|
// when llama-swap is shutting down.
|
||||||
|
func (pm *ProxyManager) Shutdown() {
|
||||||
|
pm.Lock()
|
||||||
|
defer pm.Unlock()
|
||||||
|
|
||||||
|
// shutdown process in parallel
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range pm.currentProcesses {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(process *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
process.Shutdown()
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := []interface{}{}
|
data := []interface{}{}
|
||||||
for id, modelConfig := range pm.config.Models {
|
for id, modelConfig := range pm.config.Models {
|
||||||
@@ -127,7 +214,7 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
|
|
||||||
// Encode the data as JSON and write it to the response writer
|
// Encode the data as JSON and write it to the response writer
|
||||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -155,6 +242,21 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check if model is part of the profile
|
||||||
|
if profileName != "" {
|
||||||
|
found := false
|
||||||
|
for _, item := range pm.config.Profiles[profileName] {
|
||||||
|
if item == realModelName {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// exit early when already running, otherwise stop everything and swap
|
// exit early when already running, otherwise stop everything and swap
|
||||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||||
|
|
||||||
@@ -197,12 +299,12 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
requestedModel := c.Param("model_id")
|
requestedModel := c.Param("model_id")
|
||||||
|
|
||||||
if requestedModel == "" {
|
if requestedModel == "" {
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("model id required in path"))
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||||
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
} else {
|
} else {
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
c.Request.URL.Path = c.Param("upstreamPath")
|
c.Request.URL.Path = c.Param("upstreamPath")
|
||||||
@@ -238,22 +340,23 @@ func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
|||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestBody map[string]interface{}
|
var requestBody map[string]interface{}
|
||||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
model, ok := requestBody["model"].(string)
|
model, ok := requestBody["model"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("missing or invalid 'model' key"))
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(model); err != nil {
|
if process, err := pm.swapModel(model); err != nil {
|
||||||
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
@@ -266,6 +369,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||||
|
acceptHeader := c.GetHeader("Accept")
|
||||||
|
|
||||||
|
if strings.Contains(acceptHeader, "application/json") {
|
||||||
|
c.JSON(statusCode, gin.H{"error": message})
|
||||||
|
} else {
|
||||||
|
c.String(statusCode, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
|
pm.StopProcesses()
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
|
}
|
||||||
|
|
||||||
func ProcessKeyName(groupName, modelName string) string {
|
func ProcessKeyName(groupName, modelName string) string {
|
||||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,9 +16,14 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
|||||||
c.Header("Content-Type", "text/html")
|
c.Header("Content-Type", "text/html")
|
||||||
|
|
||||||
// Write the embedded HTML content to the response
|
// Write the embedded HTML content to the response
|
||||||
_, err := c.Writer.Write(logsHTML)
|
logsHTML, err := getHTMLFile("logs.html")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err))
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = c.Writer.Write(logsHTML)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -43,7 +48,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
notify := c.Request.Context().Done()
|
notify := c.Request.Context().Done()
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("Streaming unsupported"))
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,11 +58,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
if !skipHistory {
|
if !skipHistory {
|
||||||
history := pm.logMonitor.GetHistory()
|
history := pm.logMonitor.GetHistory()
|
||||||
if len(history) != 0 {
|
if len(history) != 0 {
|
||||||
_, err := c.Writer.Write(history)
|
c.Writer.Write(history)
|
||||||
if err != nil {
|
|
||||||
c.AbortWithError(http.StatusInternalServerError, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -68,7 +69,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
case msg := <-ch:
|
case msg := <-ch:
|
||||||
_, err := c.Writer.Write(msg)
|
_, err := c.Writer.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, err)
|
// just break the loop if we can't write for some reason
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|||||||
@@ -210,3 +210,119 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
// Ensure all expected models were returned
|
// Ensure all expected models were returned
|
||||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ProfileNonMember(t *testing.T) {
|
||||||
|
|
||||||
|
model1 := "path1/model1"
|
||||||
|
model2 := "path2/model2"
|
||||||
|
|
||||||
|
profileMemberName := ProcessKeyName("test", model1)
|
||||||
|
profileNonMemberName := ProcessKeyName("test", model2)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
model1: getTestSimpleResponderConfig("model1"),
|
||||||
|
model2: getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {model1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
// actual member of profile
|
||||||
|
{
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "model1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// actual model, but non-member will 404
|
||||||
|
{
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_Shutdown(t *testing.T) {
|
||||||
|
// make broken model configurations
|
||||||
|
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||||
|
model1Config.Proxy = "http://localhost:10001/"
|
||||||
|
|
||||||
|
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
|
||||||
|
model2Config.Proxy = "http://localhost:10002/"
|
||||||
|
|
||||||
|
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||||
|
model3Config.Proxy = "http://localhost:10003/"
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2", "model3"},
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": model1Config,
|
||||||
|
"model2": model2Config,
|
||||||
|
"model3": model3Config,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
|
||||||
|
// Start all the processes
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(modelName string) {
|
||||||
|
defer wg.Done()
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// send a request to trigger the proxy to load
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||||
|
//fmt.Println(w.Code, w.Body.String())
|
||||||
|
}(modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-time.After(time.Second)
|
||||||
|
proxy.Shutdown()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_Unload(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
proc, err := proxy.swapModel("model1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, proc)
|
||||||
|
|
||||||
|
assert.Len(t, proxy.currentProcesses, 1)
|
||||||
|
req := httptest.NewRequest("GET", "/unload", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
|
assert.Len(t, proxy.currentProcesses, 0)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user