Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 680af28bcc | |||
| d94db42ffe | |||
| 93cd83c55c | |||
| 5565fca3ac | |||
| d625ab8d92 | |||
| a3f82c140b | |||
| 5c97299e7b | |||
| 671c1a5a7b | |||
| 52c0196e0f | |||
| 3201a68a04 | |||
| 3ac94ad20e | |||
| 60355bf74a | |||
| 9b2ed244e2 | |||
| eeb72297f7 | |||
| eabfe70cc6 | |||
| 29cd98878d | |||
| b3d331da0d | |||
| 62275e078d | |||
| 88916059e1 | |||
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae | |||
| 1e25b44a06 | |||
| 0815bb4cc3 | |||
| 7187cfe52e | |||
| 24089d2d9c | |||
| ebabe55ff3 | |||
| 41a338297c | |||
| 7e3353efeb | |||
| 4ed58fb173 | |||
| f5a2be698d | |||
| f5e6ec3b7a | |||
| 3f462da146 | |||
| 48bd766536 | |||
| 8d319da4dd | |||
| be7c502448 | |||
| 92336f00bf | |||
| ed2a50d9a6 | |||
| 0acfdb9f78 | |||
| 96a8ea0241 | |||
| f20f2c9b7a | |||
| 7a97c38828 | |||
| 4885132565 | |||
| 8b46a0b7f1 | |||
| 1b6736ec6f | |||
| ddc1ce031e | |||
| 11d024bbaa | |||
| 43e23c16dc | |||
| f9c8e763ba | |||
| d7e1bb9f7c | |||
| ab93460a8b | |||
| 13d4552edc | |||
| 6667e307a2 | |||
| 7ac446e6a9 | |||
| eab9795bcc | |||
| 09bdd86b54 | |||
| 85cd74a51c |
@@ -0,0 +1,23 @@
|
||||
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "32 1 * * *"
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -0,0 +1,46 @@
|
||||
name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
delete-untagged-containers:
|
||||
needs: build-and-push
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/delete-package-versions@v5
|
||||
with:
|
||||
package-name: 'llama-swap'
|
||||
package-type: 'container'
|
||||
delete-only-untagged-versions: 'true'
|
||||
@@ -0,0 +1,32 @@
|
||||
# This workflow will build a golang project
|
||||
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
run: make simple-responder
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
@@ -5,6 +5,9 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
|
||||
+20
-1
@@ -6,6 +6,25 @@ builds:
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- freebsd
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm64
|
||||
ignore:
|
||||
- goos: freebsd
|
||||
goarch: arm64
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
|
||||
# use zip format for windows
|
||||
archives:
|
||||
- id: default
|
||||
format: tar.gz
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
builds_info:
|
||||
group: root
|
||||
owner: root
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
format: zip
|
||||
@@ -35,6 +35,11 @@ linux:
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
|
||||
# Build Windows binary
|
||||
windows:
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
@@ -60,4 +65,4 @@ release:
|
||||
git tag "$$new_tag";
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean osx linux
|
||||
.PHONY: all clean mac linux windows simple-responder
|
||||
|
||||
@@ -1,46 +1,69 @@
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
|
||||
|
||||
# llama-swap
|
||||
|
||||

|
||||
|
||||
# Introduction
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
|
||||
|
||||
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
|
||||
|
||||
## How does it work?
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If a server is already running it will stop it and start the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported. For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||
|
||||
## Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Easy to config: single yaml file
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Docker and Podman support
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- ✅ llama-swap custom API endpoints
|
||||
- `/log` - remote log monitoring
|
||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- ✅ Docker and Podman support
|
||||
- ✅ Full control over server settings per model
|
||||
|
||||
## How does llama-swap work?
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## config.yaml
|
||||
|
||||
llama-swap's configuration is purposefully simple.
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>But also very powerful ...</summary>
|
||||
|
||||
```yaml
|
||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||
# Default (and minimum) is 15 seconds
|
||||
@@ -52,15 +75,22 @@ logRequests: true
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
|
||||
# where to reach the server started by cmd, make sure the ports match
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases names to use this model for
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# check this path for an HTTP 200 OK before serving requests
|
||||
# default: /health to match llama.cpp
|
||||
@@ -73,22 +103,15 @@ models:
|
||||
# default: 0 = never unload model
|
||||
ttl: 60
|
||||
|
||||
"qwen":
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
# `useModelName` overrides the model name in the request
|
||||
# and sends a specific name to the upstream server
|
||||
useModelName: "qwen:qwq"
|
||||
|
||||
# unlisted models do not show up in /v1/models or /upstream lists
|
||||
# but they can still be requested as normal
|
||||
"qwen-unlisted":
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
unlisted: true
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"docker-llama":
|
||||
@@ -99,7 +122,7 @@ models:
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
# profiles eliminates swapping by running multiple models at the same time
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
@@ -107,21 +130,78 @@ models:
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
```
|
||||
|
||||
### Advanced Examples
|
||||
### Use Case Examples
|
||||
|
||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
||||
|
||||
### Installation
|
||||
## Configuration
|
||||
|
||||
llama-s
|
||||
|
||||
</details>
|
||||
|
||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Docker is the quickest way to try out llama-swap:
|
||||
|
||||
```
|
||||
# use CPU inference
|
||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||
|
||||
|
||||
# qwen2.5 0.5B
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
|
||||
|
||||
# SmolLM2 135M
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Docker images are nightly ...</summary>
|
||||
|
||||
They include:
|
||||
|
||||
- `ghcr.io/mostlygeek/llama-swap:cpu`
|
||||
- `ghcr.io/mostlygeek/llama-swap:cuda`
|
||||
- `ghcr.io/mostlygeek/llama-swap:intel`
|
||||
- `ghcr.io/mostlygeek/llama-swap:vulkan`
|
||||
- ROCm disabled until fixed in llama.cpp container
|
||||
|
||||
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
|
||||
|
||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||
|
||||
```
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
|
||||
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
* _Note: Windows currently untested._
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
|
||||
### Building from source
|
||||
@@ -151,11 +231,18 @@ curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
|
||||
## Systemd Unit Files
|
||||
|
||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||
|
||||
`/etc/systemd/system/llama-swap.service`
|
||||
|
||||
```
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
@@ -175,3 +262,11 @@ StartLimitInterval=30
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
## Star History
|
||||
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||
</picture>
|
||||
|
||||
Executable
+55
@@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the latest available
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
||||
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
fi
|
||||
@@ -0,0 +1,17 @@
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
@@ -0,0 +1,16 @@
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=89
|
||||
|
||||
WORKDIR /app
|
||||
RUN \
|
||||
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||
|
||||
COPY config.example.yaml /app/config.yaml
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
@@ -0,0 +1,51 @@
|
||||
# Restart llama-swap on config change
|
||||
|
||||
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#
|
||||
# A simple watch and restart llama-swap when its configuration
|
||||
# file changes. Useful for trying out configuration changes
|
||||
# without manually restarting the server each time.
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <path to config.yaml>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
while true; do
|
||||
# Start the process again
|
||||
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
||||
PID=$!
|
||||
echo "Started llama-swap with PID $PID"
|
||||
|
||||
# Wait for modifications in the specified directory or file
|
||||
inotifywait -e modify "$1"
|
||||
|
||||
# Check if process exists before sending signal
|
||||
if kill -0 $PID 2>/dev/null; then
|
||||
echo "Sending SIGTERM to $PID"
|
||||
kill -SIGTERM $PID
|
||||
wait $PID
|
||||
else
|
||||
echo "Process $PID no longer exists"
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
```
|
||||
|
||||
## Usage and output example
|
||||
|
||||
```bash
|
||||
$ ./watch-and-restart.sh config.yaml
|
||||
Started llama-swap with PID 495455
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
Sending SIGTERM to 495455
|
||||
Shutting down llama-swap
|
||||
Started llama-swap with PID 495486
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
```
|
||||
@@ -3,7 +3,11 @@ module github.com/mostlygeek/llama-swap
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -15,12 +19,10 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
@@ -29,12 +31,14 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.37.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
@@ -68,20 +78,28 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
+1
-1
@@ -47,7 +47,7 @@ func main() {
|
||||
go func() {
|
||||
<-sigChan
|
||||
fmt.Println("Shutting down llama-swap")
|
||||
proxyManager.StopProcesses()
|
||||
proxyManager.Shutdown()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
|
||||
@@ -12,12 +12,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
@@ -41,11 +43,70 @@ func main() {
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != *expectedModel {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
// issue #41
|
||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||
// Parse the multipart form
|
||||
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the model from the form values
|
||||
model := c.Request.FormValue("model")
|
||||
|
||||
if model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the file from the form
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read the file content to get its size
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
fileSize := len(fileBytes)
|
||||
|
||||
// Return a JSON response with the model and transcription text including file size
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||
"model": model,
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
@@ -17,6 +17,7 @@ type ModelConfig struct {
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
|
||||
+240
-154
@@ -17,19 +17,26 @@ import (
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateFailed ProcessState = ProcessState("failed")
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// failed a health check on start and will not be recovered
|
||||
StateFailed ProcessState = ProcessState("failed")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
sync.Mutex
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
logMonitor *LogMonitor
|
||||
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
logMonitor *LogMonitor
|
||||
healthCheckTimeout int
|
||||
healthCheckTimeout int
|
||||
healthCheckLoopInterval time.Duration
|
||||
|
||||
lastRequestHandled time.Time
|
||||
|
||||
@@ -37,31 +44,84 @@ type Process struct {
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
|
||||
// for managing shutdown state
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
logMonitor: logMonitor,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
state: StateStopped,
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
logMonitor: logMonitor,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||
state: StateStopped,
|
||||
shutdownCtx: ctx,
|
||||
shutdownCancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// start the process and returns when it is ready
|
||||
func (p *Process) start() error {
|
||||
// custom error types for swapping state
|
||||
var (
|
||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||
)
|
||||
|
||||
// swapState performs a compare and swap of the state atomically. It returns the current state
|
||||
// and an error if the swap failed.
|
||||
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state == StateReady {
|
||||
return nil
|
||||
if p.state != expectedState {
|
||||
return p.state, ErrExpectedStateMismatch
|
||||
}
|
||||
|
||||
if p.state == StateFailed {
|
||||
return fmt.Errorf("process is in a failed state and can not be restarted")
|
||||
if !isValidTransition(p.state, newState) {
|
||||
return p.state, ErrInvalidStateTransition
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
// Helper function to encapsulate transition rules
|
||||
func isValidTransition(from, to ProcessState) bool {
|
||||
switch from {
|
||||
case StateStopped:
|
||||
return to == StateStarting
|
||||
case StateStarting:
|
||||
return to == StateReady || to == StateFailed || to == StateStopping
|
||||
case StateReady:
|
||||
return to == StateStopping
|
||||
case StateStopping:
|
||||
return to == StateStopped || to == StateShutdown
|
||||
case StateFailed, StateShutdown:
|
||||
return false // No transitions allowed from these states
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
func (p *Process) start() error {
|
||||
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
@@ -69,6 +129,28 @@ func (p *Process) start() error {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||
if err == ErrExpectedStateMismatch {
|
||||
// already starting, just wait for it to complete and expect
|
||||
// it to be be in the Ready start after. If not, return an error
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
if state := p.CurrentState(); state == StateReady {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("processes was in state %v when start() was called", curState)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||
}
|
||||
}
|
||||
|
||||
p.waitStarting.Add(1)
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
p.cmd = exec.Command(args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.logMonitor
|
||||
p.cmd.Stderr = p.logMonitor
|
||||
@@ -76,8 +158,15 @@ func (p *Process) start() error {
|
||||
|
||||
err = p.cmd.Start()
|
||||
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
return err
|
||||
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
err, curState, swapErr,
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("start() failed: %v", err)
|
||||
}
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
@@ -86,35 +175,59 @@ func (p *Process) start() error {
|
||||
// 3. The health check passes
|
||||
//
|
||||
// only in the third case will the process be considered Ready to accept
|
||||
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
|
||||
defer cancelHealthCheck(nil) // clean up
|
||||
cmdWaitChan := make(chan error, 1)
|
||||
healthCheckChan := make(chan error, 1)
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
|
||||
go func() {
|
||||
// possible cmd exits early
|
||||
cmdWaitChan <- p.cmd.Wait()
|
||||
}()
|
||||
checkStartTime := time.Now()
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
go func() {
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-cmdWaitChan:
|
||||
p.state = StateFailed
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
|
||||
} else {
|
||||
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
|
||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||
if checkEndpoint != "none" {
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
cancelHealthCheck(err)
|
||||
return err
|
||||
case err := <-healthCheckChan:
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
p.state = StateFailed
|
||||
return err
|
||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
||||
context.Background(),
|
||||
checkStartTime.Add(maxDuration),
|
||||
)
|
||||
defer cancelHealthCheck()
|
||||
|
||||
loop:
|
||||
// Ready Check loop
|
||||
for {
|
||||
select {
|
||||
case <-checkDeadline.Done():
|
||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
||||
} else {
|
||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||
}
|
||||
case <-p.shutdownCtx.Done():
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
default:
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
cancelHealthCheck()
|
||||
break loop
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
endTime, _ := checkDeadline.Deadline()
|
||||
ttl := time.Until(endTime)
|
||||
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
<-time.After(p.healthCheckLoopInterval)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +238,7 @@ func (p *Process) start() error {
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for range time.Tick(time.Second) {
|
||||
if p.state != StateReady {
|
||||
if p.CurrentState() != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -141,154 +254,127 @@ func (p *Process) start() error {
|
||||
}()
|
||||
}
|
||||
|
||||
p.state = StateReady
|
||||
return nil
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) Stop() {
|
||||
// wait for any inflight requests before proceeding
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != StateReady {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() called but Process State is not READY\n")
|
||||
// calling Stop() when state is invalid is a no-op
|
||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
|
||||
return
|
||||
}
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
// this situation should never happen... but if it does just update the state
|
||||
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.\n")
|
||||
p.state = StateStopped
|
||||
return
|
||||
}
|
||||
// stop the process with a graceful exit timeout
|
||||
p.stopCommand(5 * time.Second)
|
||||
|
||||
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down
|
||||
func (p *Process) Shutdown() {
|
||||
p.shutdownCancel()
|
||||
p.stopCommand(5 * time.Second)
|
||||
p.state = StateShutdown
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
||||
defer cancelTimeout()
|
||||
|
||||
sigtermNormal := make(chan error, 1)
|
||||
go func() {
|
||||
sigtermNormal <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
fmt.Fprintf(p.logMonitor, "XXX Process for %s timed out waiting to stop, sending SIGKILL to PID: %d\n", p.ID, p.cmd.Process.Pid)
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
p.cmd.Wait()
|
||||
case err := <-sigtermNormal:
|
||||
if err != nil {
|
||||
if err.Error() != "wait: no child processes" {
|
||||
// possible that simple-responder for testing is just not
|
||||
// existing right, so suppress those errors.
|
||||
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.state = StateStopped
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||
|
||||
if err != nil {
|
||||
// check if the context was cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := context.Cause(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
if errno, ok := err.(syscall.Errno); ok {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
|
||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// wait a bit longer for TCP connection issues
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// got a response but it was not an OK
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
// prevent new requests from being made while stopping or irrecoverable
|
||||
currentState := p.CurrentState()
|
||||
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
http.Error(w, errstr, http.StatusInternalServerError)
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
+127
-6
@@ -48,6 +48,33 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
@@ -58,13 +85,17 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
||||
}
|
||||
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
@@ -138,6 +169,8 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
||||
}
|
||||
|
||||
// issue #19
|
||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||
// requests. All HTTP requests in this test should complete successfully.
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
@@ -161,8 +194,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request that should take 5 * 200ms (1 second) to complete
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
||||
// send a request where simple-responder is will wait 300ms before responding
|
||||
// this will simulate an in-progress request.
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
@@ -178,9 +212,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
}(key)
|
||||
}
|
||||
|
||||
// stop the requests in the middle
|
||||
// Stop the process while requests are still being processed
|
||||
go func() {
|
||||
<-time.After(500 * time.Millisecond)
|
||||
<-time.After(150 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
@@ -190,3 +224,90 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_SwapState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
expectedState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
||||
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := &Process{
|
||||
state: test.currentState,
|
||||
}
|
||||
|
||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||
} else if err != nil && test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
if resultState != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
||||
|
||||
// make it a lot faster
|
||||
process.healthCheckLoopInterval = time.Second
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Millisecond * 500)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
// start the process, this is a blocking call
|
||||
err := process.start()
|
||||
|
||||
wg.Wait()
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
|
||||
+210
-26
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -74,9 +77,9 @@ func New(config *Config) *ProxyManager {
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
c.AbortWithStatus(204)
|
||||
c.Header("Access-Control-Allow-Methods", "*")
|
||||
c.Header("Access-Control-Allow-Headers", "*")
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -93,6 +96,7 @@ func New(config *Config) *ProxyManager {
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
|
||||
@@ -104,6 +108,10 @@ func New(config *Config) *ProxyManager {
|
||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
|
||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
@@ -156,13 +164,38 @@ func (pm *ProxyManager) stopProcesses() {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
process.Stop()
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Stop()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
pm.currentProcesses = make(map[string]*Process)
|
||||
}
|
||||
|
||||
// Shutdown is called to shutdown all upstream processes
|
||||
// when llama-swap is shutting down.
|
||||
func (pm *ProxyManager) Shutdown() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// shutdown process in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
@@ -197,11 +230,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
defer pm.Unlock()
|
||||
|
||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
}
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
|
||||
if profileName != "" {
|
||||
if _, found := pm.config.Profiles[profileName]; !found {
|
||||
@@ -317,29 +346,148 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
model, ok := requestBody["model"].(string)
|
||||
if !ok {
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(model); err != nil {
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
if process.config.UseModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
if profileName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Parse multipart form
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get model parameter from the form
|
||||
requestedModel := c.Request.FormValue("model")
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
||||
return
|
||||
}
|
||||
|
||||
// Swap to the requested model
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get profile name and model name from the requested model
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
|
||||
// Copy all form values
|
||||
for key, values := range c.Request.MultipartForm.Value {
|
||||
for _, value := range values {
|
||||
fieldValue := value
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
if process.config.UseModelName != "" {
|
||||
fieldValue = process.config.UseModelName
|
||||
} else if profileName != "" {
|
||||
fieldValue = modelName
|
||||
}
|
||||
}
|
||||
field, err := multipartWriter.CreateFormField(key)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
||||
return
|
||||
}
|
||||
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy all files from the original request
|
||||
for key, fileHeaders := range c.Request.MultipartForm.File {
|
||||
for _, fileHeader := range fileHeaders {
|
||||
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = io.Copy(formFile, file); err != nil {
|
||||
file.Close()
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
||||
return
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer to finalize the form
|
||||
if err := multipartWriter.Close(); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new request with the reconstructed form data
|
||||
modifiedReq, err := http.NewRequestWithContext(
|
||||
c.Request.Context(),
|
||||
c.Request.Method,
|
||||
c.Request.URL.String(),
|
||||
&requestBuffer,
|
||||
)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the headers from the original request
|
||||
modifiedReq.Header = c.Request.Header.Clone()
|
||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||
|
||||
// Use the modified request for proxying
|
||||
process.ProxyRequest(c.Writer, modifiedReq)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||
@@ -352,6 +500,42 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses()
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
context.Header("Content-Type", "application/json")
|
||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||
|
||||
for _, process := range pm.currentProcesses {
|
||||
|
||||
// Append the process ID and State (multiple entries if profiles are being used).
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// Put the results under the `running` key.
|
||||
response := gin.H{
|
||||
"running": runningProcesses,
|
||||
}
|
||||
|
||||
context.JSON(http.StatusOK, response) // Always return 200 OK
|
||||
}
|
||||
|
||||
func ProcessKeyName(groupName, modelName string) string {
|
||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||
}
|
||||
|
||||
func splitRequestedModel(requestedModel string) (string, string) {
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
}
|
||||
return profileName, modelName
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
@@ -254,3 +256,388 @@ func TestProxyManager_ProfileNonMember(t *testing.T) {
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
model1Config.Proxy = "http://localhost:10001/"
|
||||
|
||||
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
|
||||
model2Config.Proxy = "http://localhost:10002/"
|
||||
|
||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||
model3Config.Proxy = "http://localhost:10003/"
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2", "model3"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": model3Config,
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Start all the processes
|
||||
var wg sync.WaitGroup
|
||||
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
|
||||
wg.Add(1)
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// send a request to trigger the proxy to load
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||
//fmt.Println(w.Code, w.Body.String())
|
||||
}(modelName)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-time.After(time.Second)
|
||||
proxy.Shutdown()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestProxyManager_Unload(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
proc, err := proxy.swapModel("model1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proc)
|
||||
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
req := httptest.NewRequest("GET", "/unload", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
assert.Len(t, proxy.currentProcesses, 0)
|
||||
}
|
||||
|
||||
// issue 62, strip profile slug from model name
|
||||
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
||||
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "ok")
|
||||
}
|
||||
|
||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Shared configuration
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
}
|
||||
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
Running []struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
} `json:"running"`
|
||||
}
|
||||
|
||||
// Create proxy once for all tests
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
t.Run("no models loaded", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response RunningResponse
|
||||
|
||||
// Check if this is a valid JSON object.
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// We should have an empty running array here.
|
||||
assert.Empty(t, response.Running, "expected no running models")
|
||||
})
|
||||
|
||||
t.Run("single model loaded", func(t *testing.T) {
|
||||
// Load just a model.
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Simulate browser call for the `/running` endpoint.
|
||||
req = httptest.NewRequest("GET", "/running", nil)
|
||||
w = httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// Check if we have a single array element.
|
||||
assert.Len(t, response.Running, 1)
|
||||
|
||||
// Is this the right model?
|
||||
assert.Equal(t, "model1", response.Running[0].Model)
|
||||
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
})
|
||||
|
||||
t.Run("multiple models via profile", func(t *testing.T) {
|
||||
// Load more than one model.
|
||||
for _, model := range []string{"model1", "model2"} {
|
||||
profileModel := ProcessKeyName("test", model)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Simulate the browser call.
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
|
||||
// The JSON response must be valid.
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// The response should contain 2 models.
|
||||
assert.Len(t, response.Running, 2)
|
||||
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
}
|
||||
|
||||
// Iterate through the models and check their states as well.
|
||||
for _, entry := range response.Running {
|
||||
_, exists := expectedModels[entry.Model]
|
||||
assert.True(t, exists, "unexpected model %s", entry.Model)
|
||||
assert.Equal(t, "ready", entry.State)
|
||||
delete(expectedModels, entry.Model)
|
||||
}
|
||||
|
||||
// Since we deleted each model while testing for its validity we should have no more models in the response.
|
||||
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
modelInput string
|
||||
expectModel string
|
||||
}{
|
||||
{
|
||||
name: "With Profile Prefix",
|
||||
modelInput: "test:TheExpectedModel",
|
||||
expectModel: "TheExpectedModel", // Profile prefix should be stripped
|
||||
},
|
||||
{
|
||||
name: "Without Profile Prefix",
|
||||
modelInput: "TheExpectedModel",
|
||||
expectModel: "TheExpectedModel", // Should remain the same
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(tc.modelInput))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
// Generate random content length between 10 and 20
|
||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||
content := make([]byte, contentLength)
|
||||
_, err = fw.Write(content)
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectModel, response["model"])
|
||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_SplitRequestedModel(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
expectedProfile string
|
||||
expectedModel string
|
||||
}{
|
||||
{"no profile", "gpt-4", "", "gpt-4"},
|
||||
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
|
||||
{"only profile", "profile1:", "profile1", ""},
|
||||
{"empty model", ":gpt-4", "", "gpt-4"},
|
||||
{"empty profile", ":", "", ""},
|
||||
{"no split char", "gpt-4", "", "gpt-4"},
|
||||
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
profileName, modelName := splitRequestedModel(tt.requestedModel)
|
||||
if profileName != tt.expectedProfile {
|
||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||
}
|
||||
if modelName != tt.expectedModel {
|
||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||
func TestProxyManager_UseModelName(t *testing.T) {
|
||||
|
||||
upstreamModelName := "upstreamModel"
|
||||
|
||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||
modelConfig.UseModelName = upstreamModelName
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1"},
|
||||
},
|
||||
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
requestedModel string
|
||||
}{
|
||||
{"useModelName over rides requested model", "model1"},
|
||||
{"useModelName over rides requested profile:model", "test:model1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(tt.requestedModel))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, upstreamModelName, response["model"])
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user