Compare commits
109 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9b2ed244e2 | |||
| eeb72297f7 | |||
| eabfe70cc6 | |||
| 29cd98878d | |||
| b3d331da0d | |||
| 62275e078d | |||
| 88916059e1 | |||
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae | |||
| 1e25b44a06 | |||
| 0815bb4cc3 | |||
| 7187cfe52e | |||
| 24089d2d9c | |||
| ebabe55ff3 | |||
| 41a338297c | |||
| 7e3353efeb | |||
| 4ed58fb173 | |||
| f5a2be698d | |||
| f5e6ec3b7a | |||
| 3f462da146 | |||
| 48bd766536 | |||
| 8d319da4dd | |||
| be7c502448 | |||
| 92336f00bf | |||
| ed2a50d9a6 | |||
| 0acfdb9f78 | |||
| 96a8ea0241 | |||
| f20f2c9b7a | |||
| 7a97c38828 | |||
| 4885132565 | |||
| 8b46a0b7f1 | |||
| 1b6736ec6f | |||
| ddc1ce031e | |||
| 11d024bbaa | |||
| 43e23c16dc | |||
| f9c8e763ba | |||
| d7e1bb9f7c | |||
| ab93460a8b | |||
| 13d4552edc | |||
| 6667e307a2 | |||
| 7ac446e6a9 | |||
| eab9795bcc | |||
| 09bdd86b54 | |||
| 85cd74a51c | |||
| 314d2f2212 | |||
| fad25f3e11 | |||
| 2c3e3e27f7 | |||
| baeb0c4e7f | |||
| 2833517eef | |||
| abdc2bfdb3 | |||
| c3b834737f | |||
| 3c8e727b73 | |||
| 3a1e9f81f1 | |||
| 72c883f36c | |||
| 1b04d034cf | |||
| 2e45f5692a | |||
| c97b80bdfe | |||
| ae3ef9bc39 | |||
| db6715bec3 | |||
| da5d9e8a6a | |||
| 84b667ca7a | |||
| 29657106fc | |||
| 9c8860471e | |||
| 9b4e3f307e | |||
| 6fe37c3abf | |||
| 7f45493a37 | |||
| 891f6a5b5a | |||
| 7183f6b43d | |||
| d89bfeb441 | |||
| 9a0c6bed40 | |||
| d6ca535939 | |||
| 27302c0c02 | |||
| d4e22cceaa | |||
| 4c94927658 | |||
| a955a4a5c0 | |||
| 22d3f1a4f9 | |||
| e2443251ad | |||
| 5fbd53c616 | |||
| 97dae50dc4 | |||
| cb978f760f | |||
| 387f0ef6c4 | |||
| 18c134624d | |||
| da2326bdc7 | |||
| da46545630 | |||
| 04b4760e7e | |||
| 9fc5d5b5eb | |||
| cf82b3c633 | |||
| e363f8f498 | |||
| c9629cf3a2 | |||
| 50426935a4 | |||
| 2fceb78e8d | |||
| 9a81c53664 | |||
| 716d37de82 | |||
| 73ad85ea69 | |||
| 533162ce6a | |||
| ba39ed4c18 | |||
| 21f54f96c2 | |||
| 7eec51f3f2 | |||
| 5021e0f299 | |||
| c9233d2c9a | |||
| a33ac6f8fb | |||
| 401aa88949 | |||
| e9e88fd229 | |||
| c3b4bb1684 | |||
| e5c909ddf7 | |||
| 36a31f450f | |||
| a8e5ee13b9 | |||
| 5944a86e86 |
@@ -0,0 +1,23 @@
|
||||
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "32 1 * * *"
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -0,0 +1,46 @@
|
||||
name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
delete-untagged-containers:
|
||||
needs: build-and-push
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/delete-package-versions@v5
|
||||
with:
|
||||
package-name: 'llama-swap'
|
||||
package-type: 'container'
|
||||
delete-only-untagged-versions: 'true'
|
||||
@@ -0,0 +1,32 @@
|
||||
# This workflow will build a golang project
|
||||
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
run: make simple-responder
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
@@ -5,6 +5,9 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
@@ -30,4 +33,4 @@ jobs:
|
||||
version: '~> v2'
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
+2
-1
@@ -2,4 +2,5 @@
|
||||
.env
|
||||
build/
|
||||
dist/
|
||||
.vscode
|
||||
.vscode
|
||||
.DS_Store
|
||||
|
||||
+8
-1
@@ -6,6 +6,13 @@ builds:
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- freebsd
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm64
|
||||
ignore:
|
||||
- goos: freebsd
|
||||
goarch: arm64
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
@@ -2,6 +2,16 @@
|
||||
APP_NAME = llama-swap
|
||||
BUILD_DIR = build
|
||||
|
||||
# Get the current Git hash
|
||||
GIT_HASH := $(shell git rev-parse --short HEAD)
|
||||
ifneq ($(shell git status --porcelain),)
|
||||
# There are untracked changes
|
||||
GIT_HASH := $(GIT_HASH)+
|
||||
endif
|
||||
|
||||
# Capture the current build date in RFC3339 format
|
||||
BUILD_DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
# Default target: Builds binaries for both OSX and Linux
|
||||
all: mac linux simple-responder
|
||||
|
||||
@@ -10,26 +20,49 @@ clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
test:
|
||||
go test -short -v ./proxy
|
||||
|
||||
test-all:
|
||||
go test -v ./proxy
|
||||
|
||||
# Build OSX binary
|
||||
mac:
|
||||
@echo "Building Mac binary..."
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||
|
||||
# Build Linux binary
|
||||
linux:
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
|
||||
# for testing things
|
||||
# Build Windows binary
|
||||
windows:
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
go build -o $(BUILD_DIR)/simple-responder misc/simple-responder/simple-responder.go
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||
|
||||
# Ensure build directory exists
|
||||
$(BUILD_DIR):
|
||||
mkdir -p $(BUILD_DIR)
|
||||
|
||||
# Create a new release tag
|
||||
release:
|
||||
@echo "Checking for unstaged changes..."
|
||||
@if [ -n "$(shell git status --porcelain)" ]; then \
|
||||
echo "Error: There are unstaged changes. Please commit or stash your changes before creating a release tag." >&2; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
# Get the highest tag in v{number} format, increment it, and create a new tag
|
||||
@highest_tag=$$(git tag --sort=-v:refname | grep -E '^v[0-9]+$$' | head -n 1 || echo "v0"); \
|
||||
new_tag="v$$(( $${highest_tag#v} + 1 ))"; \
|
||||
echo "tagging new version: $$new_tag"; \
|
||||
git tag "$$new_tag";
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean osx linux
|
||||
.PHONY: all clean mac linux windows simple-responder
|
||||
|
||||
@@ -1,44 +1,93 @@
|
||||
# llama-swap
|
||||
|
||||

|
||||
|
||||
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models on demand. So let's swap the server on demand instead!
|
||||
# llama-swap
|
||||
|
||||
llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically.
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
- ✅ easy to deploy: single binary with no dependencies
|
||||
- ✅ full control over llama-server's startup settings
|
||||
- ✅ ❤️ for users who are rely on llama.cpp for LLM inference
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||
|
||||
## Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Easy to config: single yaml file
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- ✅ Manually unload models via `/unload` endpoint ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Docker and Podman support
|
||||
|
||||
## How does llama-swap work?
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## config.yaml
|
||||
|
||||
llama-swap's configuration purposefully simple.
|
||||
llama-swap's configuration is purposefully simple.
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>But also very powerful ...</summary>
|
||||
|
||||
```yaml
|
||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||
# Default (and minimum) is 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
|
||||
# Write HTTP logs (useful for troubleshooting), defaults to false
|
||||
logRequests: true
|
||||
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
|
||||
# where to reach the server started by cmd
|
||||
# where to reach the server started by cmd, make sure the ports match
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases model names to use this configuration for
|
||||
# aliases names to use this model for
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# wait for this path to return an HTTP 200 before serving requests
|
||||
# defaults to /health to match llama.cpp
|
||||
#
|
||||
# use "none" to skip endpoint checking. This may cause requests to fail
|
||||
# until the server is ready
|
||||
# check this path for an HTTP 200 OK before serving requests
|
||||
# default: /health to match llama.cpp
|
||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
||||
# until the model is ready
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# automatically unload the model after this many seconds
|
||||
# ttl values must be a value greater than 0
|
||||
# default: 0 = never unload model
|
||||
ttl: 60
|
||||
|
||||
"qwen":
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
@@ -48,42 +97,144 @@ models:
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# unlisted models do not show up in /v1/models or /upstream lists
|
||||
# but they can still be requested as normal
|
||||
"qwen-unlisted":
|
||||
unlisted: true
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:9790"
|
||||
cmd: >
|
||||
docker run --name dockertest
|
||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
```
|
||||
|
||||
## Installation
|
||||
### Use Case Examples
|
||||
|
||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
||||
|
||||
## Configuration
|
||||
|
||||
llama-s
|
||||
|
||||
</details>
|
||||
|
||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Docker is the quickest way to try out llama-swap:
|
||||
|
||||
```
|
||||
# use CPU inference
|
||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||
|
||||
|
||||
# qwen2.5 0.5B
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
|
||||
|
||||
# SmolLM2 135M
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Docker images are nightly ...</summary>
|
||||
|
||||
They include:
|
||||
|
||||
- `ghcr.io/mostlygeek/llama-swap:cpu`
|
||||
- `ghcr.io/mostlygeek/llama-swap:cuda`
|
||||
- `ghcr.io/mostlygeek/llama-swap:intel`
|
||||
- `ghcr.io/mostlygeek/llama-swap:vulkan`
|
||||
- ROCm disabled until fixed in llama.cpp container
|
||||
|
||||
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
|
||||
|
||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||
|
||||
```
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
|
||||
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
* _Note: Windows currently untested._
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Install golang for your system
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
|
||||
## Monitoring Logs
|
||||
|
||||
The `/logs` endpoint is available to monitor what llama-swap is doing. It will send the last 10KB of logs. Useful for monitoring the output of llama-server. It also supports streaming of logs.
|
||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||
|
||||
Usage:
|
||||
Of course, CLI access is also supported:
|
||||
|
||||
```
|
||||
# basic, sends up to the last 10KB of logs
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
|
||||
# add `stream` to stream new logs as they come in
|
||||
curl -Ns 'http://host/logs?stream'
|
||||
# streams logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
|
||||
# add `skip` to skip history (only useful if used with stream)
|
||||
curl -Ns 'http://host/logs?stream&skip'
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
# will output nothing :)
|
||||
curl -Ns 'http://host/logs?skip'
|
||||
# skips history and just streams new log entries
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
|
||||
## Systemd Unit Files
|
||||
|
||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||
|
||||
`/etc/systemd/system/llama-swap.service`
|
||||
|
||||
```
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
@@ -104,8 +255,10 @@ StartLimitInterval=30
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
## Building from Source
|
||||
## Star History
|
||||
|
||||
1. Install golang for your system
|
||||
1. run `make clean all`
|
||||
1. binaries will be built into `build/` directory
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||
</picture>
|
||||
|
||||
+55
-7
@@ -2,13 +2,16 @@
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 15
|
||||
|
||||
# Log HTTP requests helpful for troubleshoot, defaults to False
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
models/llama-server-osx
|
||||
--port 8999
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
--port 9001
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
proxy: http://127.0.0.1:9001
|
||||
|
||||
# list of model name aliases this llama.cpp instance can serve
|
||||
aliases:
|
||||
@@ -17,12 +20,48 @@ models:
|
||||
# check this path for a HTTP 200 response for the server to be ready
|
||||
checkEndpoint: /health
|
||||
|
||||
# unload model after 5 seconds
|
||||
ttl: 5
|
||||
|
||||
"qwen":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9002
|
||||
aliases:
|
||||
- gpt-3.5-turbo
|
||||
|
||||
# Embedding example with Nomic
|
||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
||||
"nomic":
|
||||
proxy: http://127.0.0.1:9005
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9005
|
||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
||||
--ctx-size 8192
|
||||
--batch-size 8192
|
||||
--rope-scaling yarn
|
||||
--rope-freq-scale 0.75
|
||||
-ngl 99
|
||||
--embeddings
|
||||
|
||||
# Reranking example with bge-reranker
|
||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
||||
"bge-reranker":
|
||||
proxy: http://127.0.0.1:9006
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9006
|
||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
||||
--ctx-size 8192
|
||||
--reranking
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"dockertest":
|
||||
proxy: "http://127.0.0.1:9790"
|
||||
cmd: >
|
||||
docker run --name dockertest
|
||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
env:
|
||||
@@ -30,6 +69,7 @@ models:
|
||||
- env1=hello
|
||||
cmd: build/simple-responder --port 8999
|
||||
proxy: http://127.0.0.1:8999
|
||||
unlisted: true
|
||||
|
||||
# use "none" to skip check. Caution this may cause some requests to fail
|
||||
# until the upstream server is ready for traffic
|
||||
@@ -39,6 +79,14 @@ models:
|
||||
"broken":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
unlisted: true
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
unlisted: true
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
Executable
+55
@@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the latest available
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
||||
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
fi
|
||||
@@ -0,0 +1,17 @@
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
@@ -0,0 +1,16 @@
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=89
|
||||
|
||||
WORKDIR /app
|
||||
RUN \
|
||||
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||
|
||||
COPY config.example.yaml /app/config.yaml
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
@@ -0,0 +1,6 @@
|
||||
# Example Configs and Use Cases
|
||||
|
||||
A collections of usecases and examples for getting the most out of llama-swap.
|
||||
|
||||
* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
* [Optimizing Code Generation](benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
@@ -0,0 +1,123 @@
|
||||
# Optimizing Code Generation with llama-swap
|
||||
|
||||
Finding the best mix of settings for your hardware can be time consuming. This example demonstrates using a custom configuration file to automate testing different scenarios to find the an optimal configuration.
|
||||
|
||||
The benchmark writes a snake game in Python, TypeScript, and Swift using the Qwen 2.5 Coder models. The experiments were done using a 3090 and a P40.
|
||||
|
||||
**Benchmark Scenarios**
|
||||
|
||||
Three scenarios are tested:
|
||||
|
||||
- 3090-only: Just the main model on the 3090
|
||||
- 3090-with-draft: the main and draft models on the 3090
|
||||
- 3090-P40-draft: the main model on the 3090 with the draft model offloaded to the P40
|
||||
|
||||
**Available Devices**
|
||||
|
||||
Use the following command to list available devices IDs for the configuration:
|
||||
|
||||
```
|
||||
$ /mnt/nvme/llama-server/llama-server-f3252055 --list-devices
|
||||
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
|
||||
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
|
||||
ggml_cuda_init: found 4 CUDA devices:
|
||||
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
|
||||
Device 1: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Device 2: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Device 3: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Available devices:
|
||||
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 406 MiB free)
|
||||
CUDA1: Tesla P40 (24438 MiB, 22942 MiB free)
|
||||
CUDA2: Tesla P40 (24438 MiB, 24144 MiB free)
|
||||
CUDA3: Tesla P40 (24438 MiB, 24144 MiB free)
|
||||
```
|
||||
|
||||
**Configuration**
|
||||
|
||||
The configuration file, `benchmark-config.yaml`, defines the three scenarios:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"3090-only":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: >
|
||||
/mnt/nvme/llama-server/llama-server-f3252055
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn
|
||||
--slots
|
||||
|
||||
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
-ngl 99
|
||||
--device CUDA0
|
||||
|
||||
--ctx-size 32768
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
|
||||
"3090-with-draft":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
# --ctx-size 28500 max that can fit on 3090 after draft model
|
||||
cmd: >
|
||||
/mnt/nvme/llama-server/llama-server-f3252055
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn
|
||||
--slots
|
||||
|
||||
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
-ngl 99
|
||||
--device CUDA0
|
||||
|
||||
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
|
||||
-ngld 99
|
||||
--draft-max 16
|
||||
--draft-min 4
|
||||
--draft-p-min 0.4
|
||||
--device-draft CUDA0
|
||||
|
||||
--ctx-size 28500
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
|
||||
"3090-P40-draft":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: >
|
||||
/mnt/nvme/llama-server/llama-server-f3252055
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn --metrics
|
||||
--slots
|
||||
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
-ngl 99
|
||||
--device CUDA0
|
||||
|
||||
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
|
||||
-ngld 99
|
||||
--draft-max 16
|
||||
--draft-min 4
|
||||
--draft-p-min 0.4
|
||||
--device-draft CUDA1
|
||||
|
||||
--ctx-size 32768
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
```
|
||||
|
||||
> Note in the `3090-with-draft` scenario the `--ctx-size` had to be reduced from 32768 to to accommodate the draft model.
|
||||
|
||||
|
||||
**Running the Benchmark**
|
||||
|
||||
To run the benchmark, execute the following commands:
|
||||
|
||||
1. `llama-swap -config benchmark-config.yaml`
|
||||
1. `./run-benchmark.sh http://localhost:8080 "3090-only" "3090-with-draft" "3090-P40-draft"`
|
||||
|
||||
The [benchmark script](run-benchmark.sh) generates a CSV output of the results, which can be converted to a Markdown table for readability.
|
||||
|
||||
**Results (tokens/second)**
|
||||
|
||||
| model | python | typescript | swift |
|
||||
|-----------------|--------|------------|-------|
|
||||
| 3090-only | 34.03 | 34.01 | 34.01 |
|
||||
| 3090-with-draft | 106.65 | 70.48 | 57.89 |
|
||||
| 3090-P40-draft | 81.54 | 60.35 | 46.50 |
|
||||
|
||||
Many different factors, like the programming language, can have big impacts on the performance gains. However, with a custom configuration file for benchmarking it is easy to test the different variations to discover what's best for your hardware.
|
||||
|
||||
Happy coding!
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This script generates a CSV file showing the token/second for generating a Snake Game in python, typescript and swift
|
||||
# It was created to test the effects of speculative decoding and the various draft settings on performance.
|
||||
#
|
||||
# Writing code with a low temperature seems to provide fairly consistent logic.
|
||||
#
|
||||
# Usage: ./benchmark.sh <url> <model1> [model2 ...]
|
||||
# Example: ./benchmark.sh http://localhost:8080 model1 model2
|
||||
|
||||
if [ "$#" -lt 2 ]; then
|
||||
echo "Usage: $0 <url> <model1> [model2 ...]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
url=$1; shift
|
||||
|
||||
echo "model,python,typescript,swift"
|
||||
|
||||
for model in "$@"; do
|
||||
|
||||
echo -n "$model,"
|
||||
|
||||
for lang in "python" "typescript" "swift"; do
|
||||
# expects a llama.cpp after PR https://github.com/ggerganov/llama.cpp/pull/10548
|
||||
# (Dec 3rd/2024)
|
||||
time=$(curl -s --url "$url/v1/chat/completions" -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"top_k\": 1, \"timings_per_token\":true, \"model\":\"$model\"}" | jq -r .timings.predicted_per_second)
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
time="error"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "$lang" != "swift" ]; then
|
||||
printf "%0.2f tps," $time
|
||||
else
|
||||
printf "%0.2f tps\n" $time
|
||||
fi
|
||||
done
|
||||
done
|
||||
@@ -0,0 +1,51 @@
|
||||
# Restart llama-swap on config change
|
||||
|
||||
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#
|
||||
# A simple watch and restart llama-swap when its configuration
|
||||
# file changes. Useful for trying out configuration changes
|
||||
# without manually restarting the server each time.
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <path to config.yaml>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
while true; do
|
||||
# Start the process again
|
||||
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
||||
PID=$!
|
||||
echo "Started llama-swap with PID $PID"
|
||||
|
||||
# Wait for modifications in the specified directory or file
|
||||
inotifywait -e modify "$1"
|
||||
|
||||
# Check if process exists before sending signal
|
||||
if kill -0 $PID 2>/dev/null; then
|
||||
echo "Sending SIGTERM to $PID"
|
||||
kill -SIGTERM $PID
|
||||
wait $PID
|
||||
else
|
||||
echo "Process $PID no longer exists"
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
```
|
||||
|
||||
## Usage and output example
|
||||
|
||||
```bash
|
||||
$ ./watch-and-restart.sh config.yaml
|
||||
Started llama-swap with PID 495455
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
Sending SIGTERM to 495455
|
||||
Shutting down llama-swap
|
||||
Started llama-swap with PID 495486
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
```
|
||||
@@ -0,0 +1,124 @@
|
||||
# Speculative Decoding
|
||||
|
||||
Speculative decoding can significantly improve the tokens per second. However, this comes at the cost of increased VRAM usage for the draft model. The examples provided are based on a server with three P40s and one 3090.
|
||||
|
||||
## Coding Use Case
|
||||
|
||||
This example uses Qwen2.5 Coder 32B with the 0.5B model as a draft. A quantization of Q8_0 was chosen for the draft model, as quantization has a greater impact on smaller models.
|
||||
|
||||
The models used are:
|
||||
|
||||
* [Bartowski Qwen2.5-Coder-32B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF)
|
||||
* [Bartowski Qwen2.5-Coder-0.5B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF)
|
||||
|
||||
The llama-swap configuration is as follows:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"qwen-coder-32b-q4":
|
||||
# main model on 3090, draft on P40 #1
|
||||
cmd: >
|
||||
/mnt/nvme/llama-server/llama-server-be0e35
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn --metrics
|
||||
--slots
|
||||
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
-ngl 99
|
||||
--ctx-size 19000
|
||||
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
|
||||
-ngld 99
|
||||
--draft-max 16
|
||||
--draft-min 4
|
||||
--draft-p-min 0.4
|
||||
--device CUDA0
|
||||
--device-draft CUDA1
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
```
|
||||
|
||||
In this configuration, two GPUs are used: a 3090 (CUDA0) for the main model and a P40 (CUDA1) for the draft model. Although both models can fit on the 3090, relocating the draft model to the P40 freed up space for a larger context size. Despite the P40 being about 1/3rd the speed of the 3090, the small model still improved tokens per second.
|
||||
|
||||
Multiple tests were run with various parameters, and the fastest result was chosen for the configuration. In all tests, the 0.5B model produced the largest improvements to tokens per second.
|
||||
|
||||
Baseline: 33.92 tokens/second on 3090 without a draft model.
|
||||
|
||||
| draft-max | draft-min | draft-p-min | python | TS | swift |
|
||||
|-----------|-----------|-------------|--------|----|-------|
|
||||
| 16 | 1 | 0.9 | 71.64 | 55.55 | 48.06 |
|
||||
| 16 | 1 | 0.4 | 83.21 | 58.55 | 45.50 |
|
||||
| 16 | 1 | 0.1 | 79.72 | 55.66 | 43.94 |
|
||||
| 16 | 2 | 0.9 | 68.47 | 55.13 | 43.12 |
|
||||
| 16 | 2 | 0.4 | 82.82 | 57.42 | 48.83 |
|
||||
| 16 | 2 | 0.1 | 81.68 | 51.37 | 45.72 |
|
||||
| 16 | 4 | 0.9 | 66.44 | 48.49 | 42.40 |
|
||||
| 16 | 4 | 0.4 | _83.62_ (fastest)| _58.29_ | _50.17_ |
|
||||
| 16 | 4 | 0.1 | 82.46 | 51.45 | 40.71 |
|
||||
| 8 | 1 | 0.4 | 67.07 | 55.17 | 48.46 |
|
||||
| 4 | 1 | 0.4 | 50.13 | 44.96 | 40.79 |
|
||||
|
||||
The test script can be found in this [gist](https://gist.github.com/mostlygeek/da429769796ac8a111142e75660820f1). It is a simple curl script that prompts generating a snake game in Python, TypeScript, or Swift. Evaluation metrics were pulled from llama.cpp's logs.
|
||||
|
||||
```bash
|
||||
for lang in "python" "typescript" "swift"; do
|
||||
echo "Generating Snake Game in $lang using $model"
|
||||
curl -s --url http://localhost:8080/v1/chat/completions -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"temperature\": 0.1, \"model\":\"$model\"}" > /dev/null
|
||||
done
|
||||
```
|
||||
|
||||
Python consistently outperformed Swift in all tests, likely due to the 0.5B draft model being more proficient in generating Python code accepted by the larger 32B model.
|
||||
|
||||
## Chat
|
||||
|
||||
This configuration is for a regular chat use case. It produces approximately 13 tokens/second in typical use, up from ~9 tokens/second with only the 3xP40s. This is great news for P40 owners.
|
||||
|
||||
The models used are:
|
||||
|
||||
* [Bartowski Meta-Llama-3.1-70B-Instruct-GGUF](https://huggingface.co/bartowski/Meta-Llama-3.1-70B-Instruct-GGUF)
|
||||
* [Bartowski Llama-3.2-3B-Instruct-GGUF](https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF)
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"llama-70B":
|
||||
cmd: >
|
||||
/mnt/nvme/llama-server/llama-server-be0e35
|
||||
--host 127.0.0.1 --port 9602
|
||||
--flash-attn --metrics
|
||||
--split-mode row
|
||||
--ctx-size 80000
|
||||
--model /mnt/nvme/models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf
|
||||
-ngl 99
|
||||
--model-draft /mnt/nvme/models/Llama-3.2-3B-Instruct-Q4_K_M.gguf
|
||||
-ngld 99
|
||||
--draft-max 16
|
||||
--draft-min 1
|
||||
--draft-p-min 0.4
|
||||
--device-draft CUDA0
|
||||
--tensor-split 0,1,1,1
|
||||
```
|
||||
|
||||
In this configuration, Llama-3.1-70B is split across three P40s, and Llama-3.2-3B is on the 3090.
|
||||
|
||||
Some flags deserve further explanation:
|
||||
|
||||
* `--split-mode row` - increases inference speeds using multiple P40s by about 30%. This is a P40-specific feature.
|
||||
* `--tensor-split 0,1,1,1` - controls how the main model is split across the GPUs. This means 0% on the 3090 and an even split across the P40s. A value of `--tensor-split 0,5,4,1` would mean 0% on the 3090, 50%, 40%, and 10% respectively across the other P40s. However, this would exceed the available VRAM.
|
||||
* `--ctx-size 80000` - the maximum context size that can fit in the remaining VRAM.
|
||||
|
||||
## What is CUDA0, CUDA1, CUDA2, CUDA3?
|
||||
|
||||
These devices are the IDs used by llama.cpp.
|
||||
|
||||
```bash
|
||||
$ ./llama-server --list-devices
|
||||
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
|
||||
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
|
||||
ggml_cuda_init: found 4 CUDA devices:
|
||||
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
|
||||
Device 1: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Device 2: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Device 3: Tesla P40, compute capability 6.1, VMM: yes
|
||||
Available devices:
|
||||
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 23892 MiB free)
|
||||
CUDA1: Tesla P40 (24438 MiB, 24290 MiB free)
|
||||
CUDA2: Tesla P40 (24438 MiB, 24290 MiB free)
|
||||
CUDA3: Tesla P40 (24438 MiB, 24290 MiB free)
|
||||
```
|
||||
@@ -8,6 +8,37 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,10 +1,103 @@
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
+30
-4
@@ -3,31 +3,57 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
var version string = "0"
|
||||
var commit string = "abcd1234"
|
||||
var date = "unknown"
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("version: %s (%s), built at %s\n", version, commit, date)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
config, err := proxy.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
http.HandleFunc("/", proxyManager.HandleFunc)
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigChan
|
||||
fmt.Println("Shutting down llama-swap")
|
||||
proxyManager.Shutdown()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := http.ListenAndServe(*listenStr, nil); err != nil {
|
||||
fmt.Printf("Error starting server: %v\n", err)
|
||||
if err := proxyManager.Run(*listenStr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 51 KiB |
@@ -3,47 +3,158 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
|
||||
silent := flag.Bool("silent", false, "disable all logging")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set the header to text/plain
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
// Create a new Gin router
|
||||
r := gin.New()
|
||||
|
||||
fmt.Fprintln(w, *responseMessage)
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != *expectedModel {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
if echo == "" {
|
||||
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
}
|
||||
|
||||
// Parse the duration
|
||||
if delay == "" {
|
||||
delay = "100ms"
|
||||
}
|
||||
|
||||
t, err := time.ParseDuration(delay)
|
||||
if err != nil {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/plain")
|
||||
for _, char := range echo {
|
||||
c.Writer.Write([]byte(string(char)))
|
||||
c.Writer.Flush()
|
||||
|
||||
// wait
|
||||
<-time.After(t)
|
||||
}
|
||||
})
|
||||
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/env", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
|
||||
// Get environment variables
|
||||
envVars := os.Environ()
|
||||
|
||||
// Write each environment variable to the response
|
||||
for _, envVar := range envVars {
|
||||
fmt.Fprintln(w, envVar)
|
||||
c.String(200, envVar)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up the /health endpoint handler function
|
||||
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
response := `{"status": "ok"}`
|
||||
w.Write([]byte(response))
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
address := ":" + *port // Address with the specified port
|
||||
fmt.Printf("Server is listening on port %s\n", *port)
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||
})
|
||||
|
||||
// Start the server and log any error if it occurs
|
||||
if err := http.ListenAndServe(address, nil); err != nil {
|
||||
fmt.Printf("Error starting server: %s\n", err)
|
||||
address := "127.0.0.1:" + *port // Address with the specified port
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: address,
|
||||
Handler: r.Handler(),
|
||||
}
|
||||
|
||||
// Disable logging if the --silent flag is set
|
||||
if *silent {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("simple-responder listening on %s\n", address)
|
||||
// service connections
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("simple-responder err: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown the server with
|
||||
// a timeout of 5 seconds.
|
||||
quit := make(chan os.Signal, 1)
|
||||
// kill (no param) default send syscall.SIGTERM
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("simple-responder shutting down")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
The rerank-test.json data is from https://github.com/ggerganov/llama.cpp/pull/9510
|
||||
|
||||
To run it:
|
||||
> curl http://127.0.0.1:8080/v1/rerank -H "Content-Type: application/json" -d @reranker-test.json -v | jq .
|
||||
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"model": "bge-reranker",
|
||||
"query": "Organic skincare products for sensitive skin",
|
||||
"top_n": 3,
|
||||
"documents": [
|
||||
"Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion.",
|
||||
"New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look.",
|
||||
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille: Erleben Sie die wohltuende Wirkung unserer Bio-Hautpflege, speziell für empfindliche Haut entwickelt. Mit den beruhigenden Eigenschaften von Aloe Vera und Kamille pflegen und schützen unsere Produkte Ihre Haut auf natürliche Weise. Verabschieden Sie sich von Hautirritationen und genießen Sie einen strahlenden Teint.",
|
||||
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken: Tauchen Sie ein in die Welt der modernen Schönheit mit den neuesten Make-up-Trends. Kräftige, lebendige Farben und innovative Techniken setzen neue Maßstäbe. Von auffälligen Eyelinern bis hin zu holografischen Highlightern – lassen Sie Ihrer Kreativität freien Lauf und setzen Sie jedes Mal ein Statement.",
|
||||
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla: Descubre el poder de la naturaleza con nuestra línea de cuidado de la piel orgánico, diseñada especialmente para pieles sensibles. Enriquecidos con aloe vera y manzanilla, estos productos ofrecen una hidratación y protección suave. Despídete de las irritaciones y saluda a una piel radiante y saludable.",
|
||||
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look.",
|
||||
"针对敏感肌专门设计的天然有机护肤产品:体验由芦荟和洋甘菊提取物带来的自然呵护。我们的护肤产品特别为敏感肌设计,温和滋润,保护您的肌肤不受刺激。让您的肌肤告别不适,迎来健康光彩。",
|
||||
"新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。",
|
||||
"敏感肌のために特別に設計された天然有機スキンケア製品: アロエベラとカモミールのやさしい力で、自然の抱擁を感じてください。敏感肌用に特別に設計された私たちのスキンケア製品は、肌に優しく栄養を与え、保護します。肌トラブルにさようなら、輝く健康な肌にこんにちは。",
|
||||
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています: 今シーズンのメイクアップトレンドは、大胆な色彩と革新的な技術に注目しています。ネオンアイライナーからホログラフィックハイライターまで、クリエイティビティを解き放ち、毎回ユニークなルックを演出しましょう。"
|
||||
]
|
||||
}
|
||||
+34
-15
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/google/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -14,6 +15,8 @@ type ModelConfig struct {
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
@@ -21,26 +24,31 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, bool) {
|
||||
modelConfig, found := c.Models[modelName]
|
||||
if found {
|
||||
return modelConfig, true
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// Search through aliases to find the right config
|
||||
for _, config := range c.Models {
|
||||
for _, alias := range config.Aliases {
|
||||
if alias == modelName {
|
||||
return config, true
|
||||
}
|
||||
}
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
@@ -59,6 +67,14 @@ func LoadConfig(path string) (*Config, error) {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
@@ -68,7 +84,10 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||
|
||||
// Split the command into arguments
|
||||
args := strings.Fields(cmdStr)
|
||||
args, err := shlex.Split(cmdStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
|
||||
+66
-16
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
@@ -17,7 +17,8 @@ func TestLoadConfig(t *testing.T) {
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `models:
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
@@ -28,7 +29,17 @@ func TestLoadConfig(t *testing.T) {
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
@@ -50,14 +61,33 @@ healthCheckTimeout: 15
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
|
||||
func TestModelConfigSanitizedCommand(t *testing.T) {
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
@@ -69,7 +99,10 @@ func TestModelConfigSanitizedCommand(t *testing.T) {
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestFindConfig(t *testing.T) {
|
||||
func TestConfig_FindConfig(t *testing.T) {
|
||||
|
||||
// TODO?
|
||||
// make make this shared between the different tests
|
||||
config := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
@@ -88,36 +121,53 @@ func TestFindConfig(t *testing.T) {
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 10,
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
// Test finding a model by its name
|
||||
modelConfig, found := config.FindConfig("model1")
|
||||
modelConfig, modelId, found := config.FindConfig("model1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model by its alias
|
||||
modelConfig, found = config.FindConfig("m1")
|
||||
modelConfig, modelId, found = config.FindConfig("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model that does not exist
|
||||
modelConfig, found = config.FindConfig("model3")
|
||||
modelConfig, modelId, found = config.FindConfig("model3")
|
||||
assert.False(t, found)
|
||||
assert.Equal(t, "", modelId)
|
||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||
}
|
||||
|
||||
func TestSanitizeCommand(t *testing.T) {
|
||||
// Test a simple command
|
||||
args, err := SanitizeCommand("python model1.py")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py"}, args)
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
|
||||
// Test a command with spaces and newlines
|
||||
args, err = SanitizeCommand(`python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`)
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
--arg3 123 \
|
||||
--arg4 '"string in string"'
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
func TestMain(m *testing.M) {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
|
||||
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
// Helper function to get the binary path
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
|
||||
// Create a process configuration
|
||||
return ModelConfig{
|
||||
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
@@ -0,0 +1,14 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>llama-swap</h1>
|
||||
<p>
|
||||
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,145 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Logs</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: "Courier New", Courier, monospace;
|
||||
}
|
||||
#log-controls {
|
||||
margin: 0.5em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between; /* Spaces out elements evenly */
|
||||
}
|
||||
#log-controls input {
|
||||
flex: 1;
|
||||
}
|
||||
#log-controls input:focus {
|
||||
outline: none; /* Ensures no outline is shown when the input is focused */
|
||||
}
|
||||
#log-stream {
|
||||
flex: 1;
|
||||
margin: 0.5em;
|
||||
padding: 1em;
|
||||
background: #f4f4f4;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap; /* Ensures line wrapping */
|
||||
word-wrap: break-word; /* Ensures long words wrap */
|
||||
}
|
||||
|
||||
.regex-error {
|
||||
background-color: #ff0000 !important;
|
||||
}
|
||||
|
||||
/* Dark mode styles */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body {
|
||||
background-color: #333;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-stream {
|
||||
background: #444;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-controls input {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
|
||||
#log-controls button {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<pre id="log-stream">Waiting for logs...</pre>
|
||||
<div id="log-controls">
|
||||
<input type="text" id="filter-input" placeholder="regex filter">
|
||||
<button id="clear-button">clear</button>
|
||||
</div>
|
||||
<script>
|
||||
const logStream = document.getElementById('log-stream');
|
||||
const filterInput = document.getElementById('filter-input');
|
||||
var logData = "";
|
||||
let regexFilter = null;
|
||||
|
||||
function setupEventSource() {
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
logData += event.data;
|
||||
render()
|
||||
};
|
||||
|
||||
eventSource.onerror = function(err) {
|
||||
logData = "EventSource failed: " + err.message;
|
||||
};
|
||||
} else {
|
||||
logData = "SSE Not supported by this browser."
|
||||
}
|
||||
}
|
||||
|
||||
// poor-ai's react ¯\_(ツ)_/¯
|
||||
function render() {
|
||||
if (regexFilter) {
|
||||
const lines = logData.split('\n');
|
||||
const filteredLines = lines.filter(line => {
|
||||
return regexFilter === null || regexFilter.test(line);
|
||||
});
|
||||
|
||||
if (filteredLines.length > 0) {
|
||||
logStream.textContent = filteredLines.join('\n') + '\n';
|
||||
} else {
|
||||
logStream.textContent = "";
|
||||
}
|
||||
} else {
|
||||
logStream.textContent = logData;
|
||||
}
|
||||
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
}
|
||||
|
||||
function updateFilter() {
|
||||
const pattern = filterInput.value.trim();
|
||||
filterInput.classList.remove('regex-error');
|
||||
if (pattern) {
|
||||
try {
|
||||
regexFilter = new RegExp(pattern);
|
||||
} catch (e) {
|
||||
console.error("Invalid regex pattern:", e);
|
||||
regexFilter = null;
|
||||
filterInput.classList.add('regex-error');
|
||||
return
|
||||
}
|
||||
} else {
|
||||
regexFilter = null;
|
||||
}
|
||||
|
||||
render();
|
||||
}
|
||||
|
||||
filterInput.addEventListener('input', updateFilter);
|
||||
document.getElementById('clear-button').addEventListener('click', () => {
|
||||
filterInput.value = "";
|
||||
regexFilter = null;
|
||||
render();
|
||||
});
|
||||
setupEventSource();
|
||||
updateFilter();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,10 @@
|
||||
package proxy
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed html
|
||||
var htmlFiles embed.FS
|
||||
|
||||
func getHTMLFile(path string) ([]byte, error) {
|
||||
return htmlFiles.ReadFile("html/" + path)
|
||||
}
|
||||
+1
-1
@@ -46,7 +46,7 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
w.buffer = w.buffer.Next()
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
w.broadcast(p)
|
||||
w.broadcast(bufferCopy)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config *Config
|
||||
currentCmd *exec.Cmd
|
||||
currentConfig ModelConfig
|
||||
logMonitor *LogMonitor
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
return &ProxyManager{config: config, logMonitor: NewLogMonitor()}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints
|
||||
if r.URL.Path == "/v1/chat/completions" {
|
||||
// extracts the `model` from json body
|
||||
pm.proxyChatRequest(w, r)
|
||||
} else if r.URL.Path == "/v1/models" {
|
||||
pm.listModels(w, r)
|
||||
} else if r.URL.Path == "/logs" {
|
||||
pm.streamLogs(w, r)
|
||||
} else {
|
||||
pm.proxyRequest(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogs(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
|
||||
notify := r.Context().Done()
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
skipHistory := r.URL.Query().Has("skip")
|
||||
if !skipHistory {
|
||||
// Send history first
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
w.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if !r.URL.Query().Has("stream") {
|
||||
return
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
w.Write(msg)
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModels(w http.ResponseWriter, _ *http.Request) {
|
||||
data := []interface{}{}
|
||||
for id := range pm.config.Models {
|
||||
data = append(data, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
"created": time.Now().Unix(),
|
||||
"owned_by": "llama-swap",
|
||||
})
|
||||
}
|
||||
|
||||
// Set the Content-Type header to application/json
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
http.Error(w, "Error encoding JSON", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapModel(requestedModel string) error {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// find the model configuration matching requestedModel
|
||||
modelConfig, found := pm.config.FindConfig(requestedModel)
|
||||
if !found {
|
||||
return fmt.Errorf("could not find configuration for %s", requestedModel)
|
||||
}
|
||||
|
||||
// no need to swap llama.cpp instances
|
||||
if pm.currentConfig.Cmd == modelConfig.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// kill the current running one to swap it
|
||||
if pm.currentCmd != nil {
|
||||
pm.currentCmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
// wait for it to end
|
||||
pm.currentCmd.Process.Wait()
|
||||
}
|
||||
|
||||
pm.currentConfig = modelConfig
|
||||
|
||||
args, err := modelConfig.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
|
||||
// logMonitor only writes to stdout
|
||||
// so the upstream's stderr will go to os.Stdout
|
||||
cmd.Stdout = pm.logMonitor
|
||||
cmd.Stderr = pm.logMonitor
|
||||
|
||||
cmd.Env = modelConfig.Env
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pm.currentCmd = cmd
|
||||
|
||||
// watch for the command to exist
|
||||
cmdCtx, cancel := context.WithCancelCause(context.Background())
|
||||
|
||||
// monitor the command's exist status
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
if err != nil {
|
||||
cancel(fmt.Errorf("command [%s] %s", strings.Join(cmd.Args, " "), err.Error()))
|
||||
} else {
|
||||
cancel(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for checkHealthEndpoint
|
||||
if err := pm.checkHealthEndpoint(cmdCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) checkHealthEndpoint(cmdCtx context.Context) error {
|
||||
|
||||
if pm.currentConfig.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := pm.currentConfig.Proxy
|
||||
maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(cmdCtx, 250*time.Millisecond)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||
|
||||
if err != nil {
|
||||
// check if the context was cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return context.Cause(ctx)
|
||||
default:
|
||||
}
|
||||
|
||||
// wait a bit longer for TCP connection issues
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
fmt.Fprintf(pm.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
model, ok := requestBody["model"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Missing or invalid 'model' key", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := pm.swapModel(model); err != nil {
|
||||
http.Error(w, fmt.Sprintf("unable to swap to model, %s", err.Error()), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
pm.proxyRequest(w, r)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if pm.currentConfig.Proxy == "" {
|
||||
http.Error(w, "No upstream proxy", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxyTo := pm.currentConfig.Proxy
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// faster than io.Copy when streaming
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
http.Error(w, writeErr.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,434 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateStarting ProcessState = ProcessState("starting")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// failed a health check on start and will not be recovered
|
||||
StateFailed ProcessState = ProcessState("failed")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
logMonitor *LogMonitor
|
||||
healthCheckTimeout int
|
||||
|
||||
lastRequestHandled time.Time
|
||||
|
||||
stateMutex sync.RWMutex
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
|
||||
// for managing shutdown state
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
logMonitor: logMonitor,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
state: StateStopped,
|
||||
shutdownCtx: ctx,
|
||||
shutdownCancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) setState(newState ProcessState) error {
|
||||
// enforce valid state transitions
|
||||
invalidTransition := false
|
||||
if p.state == StateStopped {
|
||||
// stopped -> starting
|
||||
if newState != StateStarting {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateStarting {
|
||||
// starting -> ready | failed | stopping
|
||||
if newState != StateReady && newState != StateFailed && newState != StateStopping {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateReady {
|
||||
// ready -> stopping
|
||||
if newState != StateStopping {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateStopping {
|
||||
// stopping -> stopped | shutdown
|
||||
if newState != StateStopped && newState != StateShutdown {
|
||||
invalidTransition = true
|
||||
}
|
||||
} else if p.state == StateFailed || p.state == StateShutdown {
|
||||
invalidTransition = true
|
||||
}
|
||||
|
||||
if invalidTransition {
|
||||
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
|
||||
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
func (p *Process) start() error {
|
||||
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||
}
|
||||
|
||||
// wait for the other start() to complete
|
||||
curState := p.CurrentState()
|
||||
|
||||
if curState == StateReady {
|
||||
return nil
|
||||
}
|
||||
|
||||
if curState == StateStarting {
|
||||
p.waitStarting.Wait()
|
||||
|
||||
if state := p.CurrentState(); state != StateReady {
|
||||
return fmt.Errorf("start() failed current state: %v", state)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// There is the possibility of a hard to replicate race condition where
|
||||
// curState *WAS* StateStopped but by the time we get to the p.stateMutex.Lock()
|
||||
// below, it's value has changed!
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
// with the exclusive lock, check if p.state is StateStopped, which is the only valid state
|
||||
// to transition from to StateReady
|
||||
|
||||
if p.state != StateStopped {
|
||||
if p.state == StateReady {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("start() can not proceed expected StateReady but process is in %v", p.state)
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.setState(StateStarting); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.waitStarting.Add(1)
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
p.cmd = exec.Command(args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.logMonitor
|
||||
p.cmd.Stderr = p.logMonitor
|
||||
p.cmd.Env = p.config.Env
|
||||
|
||||
err = p.cmd.Start()
|
||||
|
||||
if err != nil {
|
||||
p.setState(StateFailed)
|
||||
return fmt.Errorf("start() failed: %v", err)
|
||||
}
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
// 1. The command exits unexpectedly
|
||||
// 2. The health check fails
|
||||
// 3. The health check passes
|
||||
//
|
||||
// only in the third case will the process be considered Ready to accept
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
|
||||
checkStartTime := time.Now()
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||
if checkEndpoint != "none" {
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
||||
context.Background(),
|
||||
checkStartTime.Add(maxDuration),
|
||||
)
|
||||
defer cancelHealthCheck()
|
||||
|
||||
// Health check loop
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-checkDeadline.Done():
|
||||
p.setState(StateFailed)
|
||||
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
|
||||
case <-p.shutdownCtx.Done():
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
default:
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
cancelHealthCheck()
|
||||
break loop
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
endTime, _ := checkDeadline.Deadline()
|
||||
ttl := time.Until(endTime)
|
||||
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
<-time.After(5 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
if p.config.UnloadAfter > 0 {
|
||||
// start a goroutine to check every second if
|
||||
// the process should be stopped
|
||||
go func() {
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for range time.Tick(time.Second) {
|
||||
if p.state != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for all inflight requests to complete and ticker
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return p.setState(StateReady)
|
||||
}
|
||||
|
||||
func (p *Process) Stop() {
|
||||
// wait for any inflight requests before proceeding
|
||||
p.inFlightRequests.Wait()
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
// calling Stop() when state is invalid is a no-op
|
||||
if err := p.setState(StateStopping); err != nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// stop the process with a graceful exit timeout
|
||||
p.stopCommand(5 * time.Second)
|
||||
|
||||
if err := p.setState(StateStopped); err != nil {
|
||||
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down
|
||||
func (p *Process) Shutdown() {
|
||||
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
|
||||
p.shutdownCancel()
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
p.setState(StateStopping)
|
||||
|
||||
// 5 seconds to stop the process
|
||||
p.stopCommand(5 * time.Second)
|
||||
if err := p.setState(StateShutdown); err != nil {
|
||||
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
|
||||
}
|
||||
p.setState(StateShutdown)
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
||||
defer cancelTimeout()
|
||||
|
||||
sigtermNormal := make(chan error, 1)
|
||||
go func() {
|
||||
sigtermNormal <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
case err := <-sigtermNormal:
|
||||
if err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
|
||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
|
||||
} else {
|
||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// got a response but it was not an OK
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// prevent new requests from being made while stopping or irrecoverable
|
||||
currentState := p.CurrentState()
|
||||
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// faster than io.Copy when streaming
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,308 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// process is automatically started
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
|
||||
// Stop the process
|
||||
process.Stop()
|
||||
|
||||
req = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
// Proxy the request
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// should have automatically started the process again
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
config := ModelConfig{
|
||||
Cmd: "nonexistent-command",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
||||
}
|
||||
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
}
|
||||
|
||||
expectedMessage := "I_sense_imminent_danger"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
|
||||
process.ProxyRequest(w, req1)
|
||||
|
||||
t.Log("sending slow first request (4 seconds)")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "1234")
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// ensure the TTL timeout does not race slow requests (see issue #25)
|
||||
t.Log("sending second request (1 second)")
|
||||
time.Sleep(time.Second)
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req2)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// wait 5 seconds
|
||||
t.Log("sleep 5 seconds and check if unloaded")
|
||||
time.Sleep(5 * time.Second)
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_LowTTLValue(t *testing.T) {
|
||||
if true { // change this code to run this ...
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
config := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Logf("Waiting before sending request %d", i)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
expected := fmt.Sprintf("echo=test_%d", i)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// issue #19
|
||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||
// requests. All HTTP requests in this test should complete successfully.
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
"12345": "",
|
||||
"abcde": "",
|
||||
"fghij": "",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
for key := range results {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request where simple-responder is will wait 300ms before responding
|
||||
// this will simulate an in-progress request.
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
results[key] = w.Body.String()
|
||||
mu.Unlock()
|
||||
|
||||
}(key)
|
||||
}
|
||||
|
||||
// Stop the process while requests are still being processed
|
||||
go func() {
|
||||
<-time.After(150 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for key, result := range results {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
|
||||
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
|
||||
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
|
||||
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
|
||||
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
|
||||
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
|
||||
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
|
||||
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := &Process{
|
||||
state: test.currentState,
|
||||
}
|
||||
|
||||
err := p.setState(test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||
} else if err != nil && test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.state != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Second * 2)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
// start the process, this is a blocking call
|
||||
err := process.start()
|
||||
|
||||
wg.Wait()
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
@@ -0,0 +1,401 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
PROFILE_SPLIT_CHAR = ":"
|
||||
)
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config *Config
|
||||
currentProcesses map[string]*Process
|
||||
logMonitor *LogMonitor
|
||||
ginEngine *gin.Engine
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
currentProcesses: make(map[string]*Process),
|
||||
logMonitor: NewLogMonitor(),
|
||||
ginEngine: gin.New(),
|
||||
}
|
||||
|
||||
if config.LogRequests {
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
// capture these because /upstream/:model rewrites them in c.Next()
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Stop timer
|
||||
duration := time.Since(start)
|
||||
|
||||
statusCode := c.Writer.Status()
|
||||
bodySize := c.Writer.Size()
|
||||
|
||||
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
|
||||
clientIP,
|
||||
time.Now().Format("2006-01-02 15:04:05"),
|
||||
method,
|
||||
path,
|
||||
c.Request.Proto,
|
||||
statusCode,
|
||||
bodySize,
|
||||
c.Request.UserAgent(),
|
||||
duration,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
||||
// respond with permissive OPTIONS for any endpoint
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||
|
||||
// Support embeddings
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
|
||||
// in proxymanager_loghandlers.go
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||
|
||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
|
||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
htmlData, err := getHTMLFile("index.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_, err = c.Writer.Write(htmlData)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||
if data, err := getHTMLFile("favicon.ico"); err == nil {
|
||||
c.Data(http.StatusOK, "image/x-icon", data)
|
||||
} else {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
// Disable console color for testing
|
||||
gin.DisableConsoleColor()
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Run(addr ...string) error {
|
||||
return pm.ginEngine.Run(addr...)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||
pm.ginEngine.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) StopProcesses() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
pm.stopProcesses()
|
||||
}
|
||||
|
||||
// for internal usage
|
||||
func (pm *ProxyManager) stopProcesses() {
|
||||
if len(pm.currentProcesses) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Stop()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
pm.currentProcesses = make(map[string]*Process)
|
||||
}
|
||||
|
||||
// Shutdown is called to shutdown all upstream processes
|
||||
// when llama-swap is shutting down.
|
||||
func (pm *ProxyManager) Shutdown() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// shutdown process in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
data = append(data, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
"created": time.Now().Unix(),
|
||||
"owned_by": "llama-swap",
|
||||
})
|
||||
}
|
||||
|
||||
// Set the Content-Type header to application/json
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
|
||||
if profileName != "" {
|
||||
if _, found := pm.config.Profiles[profileName]; !found {
|
||||
return nil, fmt.Errorf("model group not found %s", profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(modelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
// check if model is part of the profile
|
||||
if profileName != "" {
|
||||
found := false
|
||||
for _, item := range pm.config.Profiles[profileName] {
|
||||
if item == realModelName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// exit early when already running, otherwise stop everything and swap
|
||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||
|
||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// stop all running models
|
||||
pm.stopProcesses()
|
||||
|
||||
if profileName == "" {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
} else {
|
||||
for _, modelName := range pm.config.Profiles[profileName] {
|
||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// requestedProcessKey should exist due to swap
|
||||
return pm.currentProcesses[requestedProcessKey], nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
requestedModel := c.Param("model_id")
|
||||
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
} else {
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
||||
var html strings.Builder
|
||||
|
||||
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
|
||||
|
||||
// Extract keys and sort them
|
||||
var modelIDs []string
|
||||
for modelID, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
modelIDs = append(modelIDs, modelID)
|
||||
}
|
||||
sort.Strings(modelIDs)
|
||||
|
||||
// Iterate over sorted keys
|
||||
for _, modelID := range modelIDs {
|
||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
|
||||
}
|
||||
html.WriteString("</ul></body></html>")
|
||||
c.Header("Content-Type", "text/html")
|
||||
c.String(http.StatusOK, html.String())
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
return
|
||||
}
|
||||
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
|
||||
// strip
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
if profileName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||
acceptHeader := c.GetHeader("Accept")
|
||||
|
||||
if strings.Contains(acceptHeader, "application/json") {
|
||||
c.JSON(statusCode, gin.H{"error": message})
|
||||
} else {
|
||||
c.String(statusCode, message)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses()
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
|
||||
func ProcessKeyName(groupName, modelName string) string {
|
||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||
}
|
||||
|
||||
func splitRequestedModel(requestedModel string) (string, string) {
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
}
|
||||
return profileName, modelName
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "text/html") {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
logsHTML, err := getHTMLFile("logs.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_, err = c.Writer.Write(logsHTML)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.logMonitor.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||
return
|
||||
}
|
||||
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
// Send history first if not skipped
|
||||
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.Writer.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
_, err := c.Writer.Write(msg)
|
||||
if err != nil {
|
||||
// just break the loop if we can't write for some reason
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
|
||||
// Send history first if not skipped
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.SSEvent("message", string(history))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
c.SSEvent("message", string(msg))
|
||||
c.Writer.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,351 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||
|
||||
}
|
||||
|
||||
// make sure there's only one loaded model
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
}
|
||||
|
||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileModel1 := ProcessKeyName("test", model1)
|
||||
profileModel2 := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1, model2},
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
for modelID, requestedModel := range map[string]string{
|
||||
"model1": profileModel1,
|
||||
"model2": profileModel2,
|
||||
} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelID)
|
||||
}
|
||||
|
||||
// make sure there's two loaded models
|
||||
assert.Len(t, proxy.currentProcesses, 2)
|
||||
_, exists := proxy.currentProcesses[profileModel1]
|
||||
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
||||
|
||||
_, exists = proxy.currentProcesses[profileModel2]
|
||||
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
||||
}
|
||||
|
||||
// When a request for a different model comes in ProxyManager should wait until
|
||||
// the first request is complete before swapping. Both requests should complete
|
||||
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
results := map[string]string{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
for key := range config.Models {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
|
||||
results[key] = w.Body.String()
|
||||
mu.Unlock()
|
||||
}(key)
|
||||
|
||||
<-time.After(time.Millisecond)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Len(t, results, len(config.Models))
|
||||
|
||||
for key, result := range results {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Add("Origin", "i-am-the-origin")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the listModelsHandler
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
// Check the response status code
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check for Access-Control-Allow-Origin
|
||||
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
|
||||
|
||||
// Parse the JSON response
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Check the number of models returned
|
||||
assert.Len(t, response.Data, 3)
|
||||
|
||||
// Check the details of each model
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
}
|
||||
|
||||
for _, model := range response.Data {
|
||||
modelID, ok := model["id"].(string)
|
||||
assert.True(t, ok, "model ID should be a string")
|
||||
_, exists := expectedModels[modelID]
|
||||
assert.True(t, exists, "unexpected model ID: %s", modelID)
|
||||
delete(expectedModels, modelID)
|
||||
|
||||
object, ok := model["object"].(string)
|
||||
assert.True(t, ok, "object should be a string")
|
||||
assert.Equal(t, "model", object)
|
||||
|
||||
created, ok := model["created"].(float64)
|
||||
assert.True(t, ok, "created should be a number")
|
||||
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
|
||||
|
||||
ownedBy, ok := model["owned_by"].(string)
|
||||
assert.True(t, ok, "owned_by should be a string")
|
||||
assert.Equal(t, "llama-swap", ownedBy)
|
||||
}
|
||||
|
||||
// Ensure all expected models were returned
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
func TestProxyManager_ProfileNonMember(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileMemberName := ProcessKeyName("test", model1)
|
||||
profileNonMemberName := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1},
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
// actual member of profile
|
||||
{
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "model1")
|
||||
}
|
||||
|
||||
// actual model, but non-member will 404
|
||||
{
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
model1Config.Proxy = "http://localhost:10001/"
|
||||
|
||||
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
|
||||
model2Config.Proxy = "http://localhost:10002/"
|
||||
|
||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||
model3Config.Proxy = "http://localhost:10003/"
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2", "model3"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": model3Config,
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Start all the processes
|
||||
var wg sync.WaitGroup
|
||||
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
|
||||
wg.Add(1)
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// send a request to trigger the proxy to load
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||
//fmt.Println(w.Code, w.Body.String())
|
||||
}(modelName)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-time.After(time.Second)
|
||||
proxy.Shutdown()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestProxyManager_Unload(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
proc, err := proxy.swapModel("model1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proc)
|
||||
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
req := httptest.NewRequest("GET", "/unload", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
assert.Len(t, proxy.currentProcesses, 0)
|
||||
}
|
||||
|
||||
// issue 62, strip profile slug from model name
|
||||
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
||||
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "ok")
|
||||
}
|
||||
Reference in New Issue
Block a user