Compare commits

...

73 Commits

Author SHA1 Message Date
Benson Wong 314d2f2212 remove cmd_stop configuration and functionality from PR #40 (#44)
* remove cmd_stop functionality from #40
2025-01-31 12:42:44 -08:00
Benson Wong fad25f3e11 Use client request context in proxy request (#43)
Canceled or closed HTTP requests from clients will also stop the proxied
HTTP requests to upstreamed servers.
2025-01-31 10:21:49 -08:00
Benson Wong 2c3e3e27f7 Support OPTIONS requests (#42)
Add middleware that responds with permissive OPTIONS headers
for all request paths.
2025-01-31 10:09:07 -08:00
Benson Wong baeb0c4e7f Add cmd_stop configuration to better support docker (#35)
Add `cmd_stop` to model configuration to run a command instead of sending a SIGTERM to shutdown a process before swapping.
2025-01-30 16:59:57 -08:00
Benson Wong 2833517eef Improve handling of process that do not handle SIGTERM (#38)
- Process TTL goroutine did not have a return after .Stop()
- Improve logging
- Add test TestProcess_LowTTLValue to measure SIGTERM error rate
2025-01-20 14:39:52 -08:00
Benson Wong abdc2bfdb3 Fix panic when requesting non-members of profiles
A panic occurs when a request for an invalid profile:model pair is made.
The edge case is that the profile exists and the model exists but they're
not configured as a pair.

This adds an additional check to make sure the profile:model pair is
valid before attempting to swap the model.
2025-01-16 12:06:38 -08:00
Benson Wong c3b834737f Update README.md 2025-01-13 22:37:30 -08:00
Benson Wong 3c8e727b73 Update README.md 2025-01-12 19:48:35 -08:00
Benson Wong 3a1e9f81f1 support TTS /v1/audio/speech (#36) 2025-01-12 16:27:01 -08:00
Benson Wong 72c883f36c Update README.md 2025-01-02 09:01:51 -08:00
Benson Wong 1b04d034cf Update README.md 2025-01-02 08:59:11 -08:00
Benson Wong 2e45f5692a Update README.md
Improve README documentation.
2025-01-01 12:51:24 -08:00
Benson Wong c97b80bdfe Update README.md 2025-01-01 12:25:45 -08:00
Benson Wong ae3ef9bc39 Refactor UI (#33)
- add html to / instead of 404
- add client side regex to /logs
2024-12-23 19:48:59 -08:00
Benson Wong db6715bec3 update golang.org/x/net -> v0.33.0 for dependabot 2024-12-20 11:28:32 -08:00
Benson Wong da5d9e8a6a fix HTTP logging so true path is printed 2024-12-20 11:25:01 -08:00
Benson Wong 84b667ca7a improve logging and error reporting for troubleshooting 2024-12-20 10:46:56 -08:00
Benson Wong 29657106fc add more OpenAI API supported in README 2024-12-20 10:08:20 -08:00
Benson Wong 9c8860471e support v1/rerank endpoint 2024-12-17 21:22:25 -08:00
Benson Wong 9b4e3f307e rename proxy handler 2024-12-17 17:25:10 -08:00
Benson Wong 6fe37c3abf support /v1/embeddings (#4) 2024-12-17 17:25:10 -08:00
Benson Wong 7f45493a37 Update README.md 2024-12-17 14:45:41 -08:00
Benson Wong 891f6a5b5a Add /upstream endpoint (#30)
* remove catch-all route to upstream proxy (it was broken anyways)
* add /upstream/:model_id to swap and route to upstream path
* add /upstream HTML endpoint and unlisted option
* add /upstream endpoint to show a list of available models
* add `unlisted` configuration option to omit a model from /v1/models and /upstream lists
* add favicon.ico
2024-12-17 14:37:44 -08:00
Benson Wong 7183f6b43d fix bad logging due to wrong []byte used #28 2024-12-16 16:22:14 -08:00
Benson Wong d89bfeb441 add .DS_Store to .gitignore 2024-12-16 12:30:31 -08:00
Benson Wong 9a0c6bed40 Improve stop exceptions (#28) (#29)
Stop Process TTL goroutine when process is not ready (#28)

- fix issue where the goroutine will continue even though the child
  process is no longer running and the Process' state is not Ready
- fix issue where some logs were going to stdout instead of p.logMonitor
  causing them to not show up in the /logs
- add units to unloading model message
2024-12-16 12:29:25 -08:00
Benson Wong d6ca535939 tweak release tagging so it is not based on number of commits 2024-12-14 15:46:10 -08:00
Benson Wong 27302c0c02 change llama-swap to use goreleaser default ldflag values 2024-12-14 10:30:06 -08:00
Benson Wong d4e22cceaa Fix security vulnerability with golang.org/x/crypto
- does not affect the project as llama-swap does not use the crypto
  libraries
- good practice to keep security deps updated!
2024-12-14 10:20:22 -08:00
Benson Wong 4c94927658 Move release to Makefile out of goreleaser
- less complexity
- easier
- goreleaser, github, pipelines: 1...  mostlygeek: 0
2024-12-14 10:16:46 -08:00
Benson Wong a955a4a5c0 create tag to release 2024-12-14 10:07:20 -08:00
Benson Wong 22d3f1a4f9 Change versioning to use git commits counts instead of semver
- less work for me
- more frequent releases
2024-12-14 09:53:13 -08:00
Benson Wong e2443251ad update readme 2024-12-09 19:14:49 -08:00
Benson Wong 5fbd53c616 delay TTL check until after all requests are complete (#25)
- fixes #25 where requests that last longer than the TTL will cause the
  process to be unloaded before the next request.
- new behavior, TTL waits until all requests are complete before
  checking timeout
2024-12-09 19:08:03 -08:00
Benson Wong 97dae50dc4 update readme 2024-12-08 21:34:16 -08:00
Benson Wong cb978f760f add web interface to /logs 2024-12-08 21:26:22 -08:00
Benson Wong 387f0ef6c4 use new timings data in server response in run-benchmark.sh 2024-12-03 20:48:36 -08:00
Benson Wong 18c134624d Add Access-Control-Allow-Origin CORS header to /v1/models endpoint
- match behavior of llama.cpp where the Origin in request is used
- add test for listModelsHandler
2024-12-03 15:53:59 -08:00
Benson Wong da2326bdc7 add example: optimizing code generation 2024-12-03 10:25:43 -08:00
Benson Wong da46545630 fix profile example in README 2024-12-01 10:13:31 -08:00
Benson Wong 04b4760e7e change profile split character to : (colon) (#21)
- change from `/` to `:` for multiple models loaded as part of a profile
- breaking change now, but allows for more compatibility with other inference engines that may have model references like `coding:Qwen/Qwen-2.5-Coder-32B`
2024-12-01 09:10:50 -08:00
Benson Wong 9fc5d5b5eb improve cmd parsing (#22)
Switch from using a naive strings.Fields() to shlex.Split() for parsing the model startup command into a string[]. This makes parsing much more reliable around newlines, quotes, etc.
2024-12-01 09:02:58 -08:00
Benson Wong cf82b3c633 Improve Concurrency and Parallel Request Handling (#19)
Rewrite the swap behaviour so that in-flight requests block process swapping until they are completed. 

Additionally: 

- add tests for parallel requests with proxy.ProxyManager and proxy.Process
- improve Process startup behaviour and simplified the code 
- stopping of processes are sent SIGTERM and have 5 seconds to terminate, before they are killed
2024-11-30 15:24:42 -08:00
Benson Wong e363f8f498 clean up writing with AI :b 2024-11-28 22:12:44 -08:00
Benson Wong c9629cf3a2 add speculative decoding example 2024-11-28 22:07:22 -08:00
Benson Wong 50426935a4 . 2024-11-28 22:06:29 -08:00
Benson Wong 2fceb78e8d Add examples 2024-11-28 22:05:41 -08:00
Ikko Eltociear Ashimine 9a81c53664 chore: update process_test.go (#17)
nonexistant -> nonexistent
2024-11-26 10:20:16 -08:00
Benson Wong 716d37de82 Update README.md
fix grammar
2024-11-25 12:35:00 -08:00
Benson Wong 73ad85ea69 Implement Multi-Process Handling (#7)
Refactor code to support starting of multiple back end llama.cpp servers. This functionality is exposed as `profiles` to create a simple configuration format. 

Changes: 

* refactor proxy tests to get ready for multi-process support
* update proxy/ProxyManager to support multiple processes (#7)
* Add support for Groups in configuration
* improve handling of Model alias configs
* implement multi-model swapping
* improve code clarity for swapModel
* improve docs, rename groups to profiles in config
2024-11-23 19:45:13 -08:00
Benson Wong 533162ce6a add support for automatically unloading a model (#10) (#14)
* Make starting upstream process on-demand (#10)
* Add automatic unload of model after TTL is reached
* add `ttl` configuration parameter to models in seconds, default is 0 (never unload)
2024-11-19 16:32:51 -08:00
Benson Wong ba39ed4c18 Add support for legacy v1/completions API (#12) 2024-11-19 09:57:39 -08:00
Benson Wong 21f54f96c2 Merge pull request #13 from mostlygeek/set-content-length
Dechunk HTTP requests by default (#11)
2024-11-19 09:46:03 -08:00
Benson Wong 7eec51f3f2 Dechunk HTTP requests by default (#11)
ProxyManager already has all the Request body's data. There is no never
a need to use chunked transfer encoding to the upstream process.
2024-11-19 09:40:44 -08:00
Benson Wong 5021e0f299 remove the process handler override 2024-11-18 21:26:39 -08:00
Benson Wong c9233d2c9a use gin instead of standard http lib in main 2024-11-18 15:58:28 -08:00
Benson Wong a33ac6f8fb update README 2024-11-18 15:37:50 -08:00
Benson Wong 401aa88949 move log handlers to separate file 2024-11-18 15:33:06 -08:00
Benson Wong e9e88fd229 rename proxy.go to proxymanager.go 2024-11-18 15:30:34 -08:00
Benson Wong c3b4bb1684 use gin for http server 2024-11-18 15:30:16 -08:00
Benson Wong e5c909ddf7 add tests for proxy.Process 2024-11-17 20:49:14 -08:00
Benson Wong 36a31f450f add proxy.Process to manage upstream proxy logic 2024-11-17 16:41:15 -08:00
Benson Wong a8e5ee13b9 Add logging with pipes example to README 2024-11-15 09:10:43 -08:00
Benson Wong 5944a86e86 fix early timeout bug 2024-11-09 20:08:40 -08:00
Benson Wong 63d4a7d0eb Improve LogMonitor to handle empty writes and ensure buffer immutability
- Add a check to return immediately if the write buffer is empty
- Create a copy of new history data to ensure it is immutable
- Update the `GetHistory` method to use the `any` type for the buffer interface
- Add a test case to verify that the buffer remains unchanged
  even if the original message is modified after writing
2024-11-02 10:41:23 -07:00
Benson Wong f45469f7ff Merge pull request #8 from mostlygeek/improve-upstream-monitoring-issue5
Improvements to handling of the upstream process so errors happen whenever one of these is first:

    the health check timeout is reached waiting for the upstream process to be ready
    the upstream process exits unexpectedly

With this change llama-swap is more compatible with use cases like containerized upstream services (#5) which pull the container before HTTP endpoints are ready.
2024-11-01 15:28:06 -07:00
Benson Wong 34f9fd7340 Improve timeout and exit handling of child processes. fix #3 and #5
llama-swap only waited a maximum of 5 seconds for an upstream
HTTP server to be available. If it took longer than that it will error
out the request. Now it will wait up to the configured healthCheckTimeout
or the upstream process unexpectedly exits.
2024-11-01 14:32:39 -07:00
Benson Wong 8448efa7fc revise health check logic to not error on 5 second timeout 2024-11-01 09:42:37 -07:00
Benson Wong 8cf2a389d8 Refactor log implementation
- use []byte instead of unnecessary string conversions
- make LogManager.Broadcast private
- make LogManager.GetHistory public
- add tests
2024-10-31 12:16:54 -07:00
Benson Wong 0f133f5b74 Add /logs endpoint to monitor upstream processes
- outputs last 10KB of logs from upstream processes
- supports streaming
2024-10-30 21:02:30 -07:00
Benson Wong 1510b3fbd9 clean up README 2024-10-22 10:37:45 -07:00
Benson Wong 0f8a8e70f1 add header image 2024-10-22 10:30:30 -07:00
Benson Wong 6c3819022c Add compatibility with OpenAI /v1/models endpoint to list models 2024-10-21 15:38:12 -07:00
32 changed files with 2547 additions and 316 deletions
+1 -1
View File
@@ -30,4 +30,4 @@ jobs:
version: '~> v2'
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+2 -1
View File
@@ -2,4 +2,5 @@
.env
build/
dist/
.vscode
.vscode
.DS_Store
+35 -4
View File
@@ -2,6 +2,16 @@
APP_NAME = llama-swap
BUILD_DIR = build
# Get the current Git hash
GIT_HASH := $(shell git rev-parse --short HEAD)
ifneq ($(shell git status --porcelain),)
# There are untracked changes
GIT_HASH := $(GIT_HASH)+
endif
# Capture the current build date in RFC3339 format
BUILD_DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
# Default target: Builds binaries for both OSX and Linux
all: mac linux simple-responder
@@ -9,24 +19,45 @@ all: mac linux simple-responder
clean:
rm -rf $(BUILD_DIR)
test:
go test -short -v ./proxy
test-all:
go test -v ./proxy
# Build OSX binary
mac:
@echo "Building Mac binary..."
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
# Build Linux binary
linux:
@echo "Building Linux binary..."
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
# for testing things
# for testing proxy.Process
simple-responder:
@echo "Building simple responder"
go build -o $(BUILD_DIR)/simple-responder misc/simple-responder/simple-responder.go
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
# Ensure build directory exists
$(BUILD_DIR):
mkdir -p $(BUILD_DIR)
# Create a new release tag
release:
@echo "Checking for unstaged changes..."
@if [ -n "$(shell git status --porcelain)" ]; then \
echo "Error: There are unstaged changes. Please commit or stash your changes before creating a release tag." >&2; \
exit 1; \
fi
# Get the highest tag in v{number} format, increment it, and create a new tag
@highest_tag=$$(git tag --sort=-v:refname | grep -E '^v[0-9]+$$' | head -n 1 || echo "v0"); \
new_tag="v$$(( $${highest_tag#v} + 1 ))"; \
echo "tagging new version: $$new_tag"; \
git tag "$$new_tag";
# Phony targets
.PHONY: all clean osx linux
+118 -25
View File
@@ -1,57 +1,156 @@
# llama-swap
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, so let's swap llama-server instead!
![llama-swap header image](header.jpeg)
llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically.
# Introduction
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
- ✅ easy to deploy: single binary with no dependencies
- ✅ full control over llama-server's startup settings
- ✅ ❤️ for nvidia P40 users who are rely on llama.cpp for inference
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
## How does it work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If a server is already running it will stop it and start the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## Do I need to use llama.cpp's server (llama-server)?
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported. For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
## Features:
- ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file
- ✅ On-demand model switching
- ✅ Full control over server settings per model
- ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/embeddings`
- `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- ✅ Multiple GPU support
- ✅ Docker and Podman support
- ✅ Run multiple models at once with `profiles`
- ✅ Remote log monitoring at `/log`
- ✅ Automatic unloading of models from GPUs after timeout
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
## config.yaml
llama-swap's configuration purposefully simple.
llama-swap's configuration is purposefully simple.
```yaml
# Seconds to wait for llama.cpp to load and be ready to serve requests
# Default (and minimum) is 15 seconds
healthCheckTimeout: 60
# Write HTTP logs (useful for troubleshooting), defaults to false
logRequests: true
# define valid model values and the upstream server start
models:
"llama":
cmd: "llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf"
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
# Where to proxy to, important it matches this format
proxy: "http://127.0.0.1:8999"
# where to reach the server started by cmd, make sure the ports match
proxy: http://127.0.0.1:8999
# aliases model names to use this configuration for
# aliases names to use this model for
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# wait for this path to return an HTTP 200 before serving requests
# defaults to /health to match llama.cpp
#
# use "none" to skip endpoint checking. This may cause requests to fail
# until the server is ready
checkEndpoint: "/custom-endpoint"
# check this path for an HTTP 200 OK before serving requests
# default: /health to match llama.cpp
# use "none" to skip endpoint checking, but may cause HTTP errors
# until the model is ready
checkEndpoint: /custom-endpoint
# automatically unload the model after this many seconds
# ttl values must be a value greater than 0
# default: 0 = never unload model
ttl: 60
"qwen":
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
cmd: "llama-server --port 8999 -m path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf"
proxy: "http://127.0.0.1:8999"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
# unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal
"qwen-unlisted":
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
unlisted: true
# Docker Support (v26.1.4+ required!)
"docker-llama":
proxy: "http://127.0.0.1:9790"
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles make it easy to managing multi model (and gpu) configurations.
#
# Tips:
# - each model must be listening on a unique address and port
# - the model name is in this format: "profile_name:model", like "coding:qwen"
# - the profile will load and unload all models in the profile at the same time
profiles:
coding:
- "qwen"
- "llama"
```
## Installation
### Advanced Examples
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
### Installation
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
* _Note: Windows currently untested._
1. Run the binary with `llama-swap --config path/to/config.yaml`
### Building from source
1. Install golang for your system
1. `git clone git@github.com:mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
## Monitoring Logs
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
Of course, CLI access is also supported:
```
# sends up to the last 10KB of logs
curl http://host/logs'
# streams logs
curl -Ns 'http://host/logs/stream'
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
# skips history and just streams new log entries
curl -Ns 'http://host/logs/stream?no-history'
```
## Systemd Unit Files
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
@@ -76,9 +175,3 @@ StartLimitInterval=30
[Install]
WantedBy=multi-user.target
```
## Building from Source
1. Install golang for your system
1. run `make clean all`
1. binaries will be built into `build/` directory
+59 -8
View File
@@ -1,14 +1,17 @@
# Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds
healthCheckTimeout: 60
healthCheckTimeout: 15
# Log HTTP requests helpful for troubleshoot, defaults to False
logRequests: true
models:
"llama":
cmd: >
models/llama-server-osx
--port 8999
-m models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
--port 9001
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
proxy: http://127.0.0.1:9001
# list of model name aliases this llama.cpp instance can serve
aliases:
@@ -17,12 +20,48 @@ models:
# check this path for a HTTP 200 response for the server to be ready
checkEndpoint: /health
# unload model after 5 seconds
ttl: 5
"qwen":
cmd: models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9002
aliases:
- gpt-3.5-turbo
# Embedding example with Nomic
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
"nomic":
proxy: http://127.0.0.1:9005
cmd: >
models/llama-server-osx --port 9005
-m models/nomic-embed-text-v1.5.Q8_0.gguf
--ctx-size 8192
--batch-size 8192
--rope-scaling yarn
--rope-freq-scale 0.75
-ngl 99
--embeddings
# Reranking example with bge-reranker
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
"bge-reranker":
proxy: http://127.0.0.1:9006
cmd: >
models/llama-server-osx --port 9006
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
--ctx-size 8192
--reranking
# Docker Support (v26.1.4+ required!)
"dockertest":
proxy: "http://127.0.0.1:9790"
cmd: >
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
"simple":
# example of setting environment variables
env:
@@ -30,12 +69,24 @@ models:
- env1=hello
cmd: build/simple-responder --port 8999
proxy: http://127.0.0.1:8999
unlisted: true
# use "none" to skip check. Caution this may cause some requests to fail
# until the upstream server is ready for traffic
checkEndpoint: none
# don't use this, just for testing if things are broken
# don't use these, just for testing if things are broken
"broken":
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
proxy: http://127.0.0.1:8999
proxy: http://127.0.0.1:8999
unlisted: true
"broken_timeout":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9000
unlisted: true
# creating a coding profile with models for code generation and general questions
profiles:
coding:
- "qwen"
- "llama"
+6
View File
@@ -0,0 +1,6 @@
# Example Configs and Use Cases
A collections of usecases and examples for getting the most out of llama-swap.
* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
* [Optimizing Code Generation](benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
+123
View File
@@ -0,0 +1,123 @@
# Optimizing Code Generation with llama-swap
Finding the best mix of settings for your hardware can be time consuming. This example demonstrates using a custom configuration file to automate testing different scenarios to find the an optimal configuration.
The benchmark writes a snake game in Python, TypeScript, and Swift using the Qwen 2.5 Coder models. The experiments were done using a 3090 and a P40.
**Benchmark Scenarios**
Three scenarios are tested:
- 3090-only: Just the main model on the 3090
- 3090-with-draft: the main and draft models on the 3090
- 3090-P40-draft: the main model on the 3090 with the draft model offloaded to the P40
**Available Devices**
Use the following command to list available devices IDs for the configuration:
```
$ /mnt/nvme/llama-server/llama-server-f3252055 --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: Tesla P40, compute capability 6.1, VMM: yes
Device 2: Tesla P40, compute capability 6.1, VMM: yes
Device 3: Tesla P40, compute capability 6.1, VMM: yes
Available devices:
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 406 MiB free)
CUDA1: Tesla P40 (24438 MiB, 22942 MiB free)
CUDA2: Tesla P40 (24438 MiB, 24144 MiB free)
CUDA3: Tesla P40 (24438 MiB, 24144 MiB free)
```
**Configuration**
The configuration file, `benchmark-config.yaml`, defines the three scenarios:
```yaml
models:
"3090-only":
proxy: "http://127.0.0.1:9503"
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--ctx-size 32768
--cache-type-k q8_0 --cache-type-v q8_0
"3090-with-draft":
proxy: "http://127.0.0.1:9503"
# --ctx-size 28500 max that can fit on 3090 after draft model
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
-ngld 99
--draft-max 16
--draft-min 4
--draft-p-min 0.4
--device-draft CUDA0
--ctx-size 28500
--cache-type-k q8_0 --cache-type-v q8_0
"3090-P40-draft":
proxy: "http://127.0.0.1:9503"
cmd: >
/mnt/nvme/llama-server/llama-server-f3252055
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--device CUDA0
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
-ngld 99
--draft-max 16
--draft-min 4
--draft-p-min 0.4
--device-draft CUDA1
--ctx-size 32768
--cache-type-k q8_0 --cache-type-v q8_0
```
> Note in the `3090-with-draft` scenario the `--ctx-size` had to be reduced from 32768 to to accommodate the draft model.
**Running the Benchmark**
To run the benchmark, execute the following commands:
1. `llama-swap -config benchmark-config.yaml`
1. `./run-benchmark.sh http://localhost:8080 "3090-only" "3090-with-draft" "3090-P40-draft"`
The [benchmark script](run-benchmark.sh) generates a CSV output of the results, which can be converted to a Markdown table for readability.
**Results (tokens/second)**
| model | python | typescript | swift |
|-----------------|--------|------------|-------|
| 3090-only | 34.03 | 34.01 | 34.01 |
| 3090-with-draft | 106.65 | 70.48 | 57.89 |
| 3090-P40-draft | 81.54 | 60.35 | 46.50 |
Many different factors, like the programming language, can have big impacts on the performance gains. However, with a custom configuration file for benchmarking it is easy to test the different variations to discover what's best for your hardware.
Happy coding!
+40
View File
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# This script generates a CSV file showing the token/second for generating a Snake Game in python, typescript and swift
# It was created to test the effects of speculative decoding and the various draft settings on performance.
#
# Writing code with a low temperature seems to provide fairly consistent logic.
#
# Usage: ./benchmark.sh <url> <model1> [model2 ...]
# Example: ./benchmark.sh http://localhost:8080 model1 model2
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <url> <model1> [model2 ...]"
exit 1
fi
url=$1; shift
echo "model,python,typescript,swift"
for model in "$@"; do
echo -n "$model,"
for lang in "python" "typescript" "swift"; do
# expects a llama.cpp after PR https://github.com/ggerganov/llama.cpp/pull/10548
# (Dec 3rd/2024)
time=$(curl -s --url "$url/v1/chat/completions" -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"top_k\": 1, \"timings_per_token\":true, \"model\":\"$model\"}" | jq -r .timings.predicted_per_second)
if [ $? -ne 0 ]; then
time="error"
exit 1
fi
if [ "$lang" != "swift" ]; then
printf "%0.2f tps," $time
else
printf "%0.2f tps\n" $time
fi
done
done
+124
View File
@@ -0,0 +1,124 @@
# Speculative Decoding
Speculative decoding can significantly improve the tokens per second. However, this comes at the cost of increased VRAM usage for the draft model. The examples provided are based on a server with three P40s and one 3090.
## Coding Use Case
This example uses Qwen2.5 Coder 32B with the 0.5B model as a draft. A quantization of Q8_0 was chosen for the draft model, as quantization has a greater impact on smaller models.
The models used are:
* [Bartowski Qwen2.5-Coder-32B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF)
* [Bartowski Qwen2.5-Coder-0.5B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF)
The llama-swap configuration is as follows:
```yaml
models:
"qwen-coder-32b-q4":
# main model on 3090, draft on P40 #1
cmd: >
/mnt/nvme/llama-server/llama-server-be0e35
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
-ngl 99
--ctx-size 19000
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
-ngld 99
--draft-max 16
--draft-min 4
--draft-p-min 0.4
--device CUDA0
--device-draft CUDA1
proxy: "http://127.0.0.1:9503"
```
In this configuration, two GPUs are used: a 3090 (CUDA0) for the main model and a P40 (CUDA1) for the draft model. Although both models can fit on the 3090, relocating the draft model to the P40 freed up space for a larger context size. Despite the P40 being about 1/3rd the speed of the 3090, the small model still improved tokens per second.
Multiple tests were run with various parameters, and the fastest result was chosen for the configuration. In all tests, the 0.5B model produced the largest improvements to tokens per second.
Baseline: 33.92 tokens/second on 3090 without a draft model.
| draft-max | draft-min | draft-p-min | python | TS | swift |
|-----------|-----------|-------------|--------|----|-------|
| 16 | 1 | 0.9 | 71.64 | 55.55 | 48.06 |
| 16 | 1 | 0.4 | 83.21 | 58.55 | 45.50 |
| 16 | 1 | 0.1 | 79.72 | 55.66 | 43.94 |
| 16 | 2 | 0.9 | 68.47 | 55.13 | 43.12 |
| 16 | 2 | 0.4 | 82.82 | 57.42 | 48.83 |
| 16 | 2 | 0.1 | 81.68 | 51.37 | 45.72 |
| 16 | 4 | 0.9 | 66.44 | 48.49 | 42.40 |
| 16 | 4 | 0.4 | _83.62_ (fastest)| _58.29_ | _50.17_ |
| 16 | 4 | 0.1 | 82.46 | 51.45 | 40.71 |
| 8 | 1 | 0.4 | 67.07 | 55.17 | 48.46 |
| 4 | 1 | 0.4 | 50.13 | 44.96 | 40.79 |
The test script can be found in this [gist](https://gist.github.com/mostlygeek/da429769796ac8a111142e75660820f1). It is a simple curl script that prompts generating a snake game in Python, TypeScript, or Swift. Evaluation metrics were pulled from llama.cpp's logs.
```bash
for lang in "python" "typescript" "swift"; do
echo "Generating Snake Game in $lang using $model"
curl -s --url http://localhost:8080/v1/chat/completions -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"temperature\": 0.1, \"model\":\"$model\"}" > /dev/null
done
```
Python consistently outperformed Swift in all tests, likely due to the 0.5B draft model being more proficient in generating Python code accepted by the larger 32B model.
## Chat
This configuration is for a regular chat use case. It produces approximately 13 tokens/second in typical use, up from ~9 tokens/second with only the 3xP40s. This is great news for P40 owners.
The models used are:
* [Bartowski Meta-Llama-3.1-70B-Instruct-GGUF](https://huggingface.co/bartowski/Meta-Llama-3.1-70B-Instruct-GGUF)
* [Bartowski Llama-3.2-3B-Instruct-GGUF](https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF)
```yaml
models:
"llama-70B":
cmd: >
/mnt/nvme/llama-server/llama-server-be0e35
--host 127.0.0.1 --port 9602
--flash-attn --metrics
--split-mode row
--ctx-size 80000
--model /mnt/nvme/models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf
-ngl 99
--model-draft /mnt/nvme/models/Llama-3.2-3B-Instruct-Q4_K_M.gguf
-ngld 99
--draft-max 16
--draft-min 1
--draft-p-min 0.4
--device-draft CUDA0
--tensor-split 0,1,1,1
```
In this configuration, Llama-3.1-70B is split across three P40s, and Llama-3.2-3B is on the 3090.
Some flags deserve further explanation:
* `--split-mode row` - increases inference speeds using multiple P40s by about 30%. This is a P40-specific feature.
* `--tensor-split 0,1,1,1` - controls how the main model is split across the GPUs. This means 0% on the 3090 and an even split across the P40s. A value of `--tensor-split 0,5,4,1` would mean 0% on the 3090, 50%, 40%, and 10% respectively across the other P40s. However, this would exceed the available VRAM.
* `--ctx-size 80000` - the maximum context size that can fit in the remaining VRAM.
## What is CUDA0, CUDA1, CUDA2, CUDA3?
These devices are the IDs used by llama.cpp.
```bash
$ ./llama-server --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: Tesla P40, compute capability 6.1, VMM: yes
Device 2: Tesla P40, compute capability 6.1, VMM: yes
Device 3: Tesla P40, compute capability 6.1, VMM: yes
Available devices:
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 23892 MiB free)
CUDA1: Tesla P40 (24438 MiB, 24290 MiB free)
CUDA2: Tesla P40 (24438 MiB, 24290 MiB free)
CUDA3: Tesla P40 (24438 MiB, 24290 MiB free)
```
+27
View File
@@ -8,6 +8,33 @@ require (
)
require (
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.10.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+83
View File
@@ -1,10 +1,93 @@
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 261 KiB

+30 -4
View File
@@ -3,31 +3,57 @@ package main
import (
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy"
)
var version string = "0"
var commit string = "abcd1234"
var date = "unknown"
func main() {
// Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port")
showVersion := flag.Bool("version", false, "show version of build")
flag.Parse() // Parse the command-line flags
if *showVersion {
fmt.Printf("version: %s (%s), built at %s\n", version, commit, date)
os.Exit(0)
}
config, err := proxy.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode)
} else {
gin.SetMode(gin.ReleaseMode)
}
proxyManager := proxy.New(config)
http.HandleFunc("/", proxyManager.HandleFunc)
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
fmt.Println("Shutting down llama-swap")
proxyManager.StopProcesses()
os.Exit(0)
}()
fmt.Println("llama-swap listening on " + *listenStr)
if err := http.ListenAndServe(*listenStr, nil); err != nil {
fmt.Printf("Error starting server: %v\n", err)
if err := proxyManager.Run(*listenStr); err != nil {
fmt.Printf("Server error: %v\n", err)
os.Exit(1)
}
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

+105 -15
View File
@@ -3,47 +3,137 @@ package main
import (
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gin-gonic/gin"
)
func main() {
gin.SetMode(gin.TestMode)
// Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on")
// Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with")
silent := flag.Bool("silent", false, "disable all logging")
flag.Parse() // Parse the command-line flags
// Set up the handler function using the provided response message
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Set the header to text/plain
w.Header().Set("Content-Type", "text/plain")
// Create a new Gin router
r := gin.New()
fmt.Fprintln(w, *responseMessage)
// Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
c.String(200, *responseMessage)
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
})
r.GET("/slow-respond", func(c *gin.Context) {
echo := c.Query("echo")
delay := c.Query("delay")
if echo == "" {
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
}
// Parse the duration
if delay == "" {
delay = "100ms"
}
t, err := time.ParseDuration(delay)
if err != nil {
c.Header("Content-Type", "text/plain")
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
return
}
c.Header("Content-Type", "text/plain")
for _, char := range echo {
c.Writer.Write([]byte(string(char)))
c.Writer.Flush()
// wait
<-time.After(t)
}
})
r.GET("/test", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
})
r.GET("/env", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
// Get environment variables
envVars := os.Environ()
// Write each environment variable to the response
for _, envVar := range envVars {
fmt.Fprintln(w, envVar)
c.String(200, envVar)
}
})
// Set up the /health endpoint handler function
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
response := `{"status": "ok"}`
w.Write([]byte(response))
r.GET("/health", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.JSON(200, gin.H{"status": "ok"})
})
address := ":" + *port // Address with the specified port
fmt.Printf("Server is listening on port %s\n", *port)
r.GET("/", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
})
// Start the server and log any error if it occurs
if err := http.ListenAndServe(address, nil); err != nil {
fmt.Printf("Error starting server: %s\n", err)
address := "127.0.0.1:" + *port // Address with the specified port
srv := &http.Server{
Addr: address,
Handler: r.Handler(),
}
// Disable logging if the --silent flag is set
if *silent {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
log.SetOutput(io.Discard)
}
go func() {
log.Printf("simple-responder listening on %s\n", address)
// service connections
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("simple-responder err: %s\n", err)
}
}()
// Wait for interrupt signal to gracefully shutdown the server with
// a timeout of 5 seconds.
quit := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("simple-responder shutting down")
}
+4
View File
@@ -0,0 +1,4 @@
The rerank-test.json data is from https://github.com/ggerganov/llama.cpp/pull/9510
To run it:
> curl http://127.0.0.1:8080/v1/rerank -H "Content-Type: application/json" -d @reranker-test.json -v | jq .
+17
View File
@@ -0,0 +1,17 @@
{
"model": "bge-reranker",
"query": "Organic skincare products for sensitive skin",
"top_n": 3,
"documents": [
"Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion.",
"New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look.",
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille: Erleben Sie die wohltuende Wirkung unserer Bio-Hautpflege, speziell für empfindliche Haut entwickelt. Mit den beruhigenden Eigenschaften von Aloe Vera und Kamille pflegen und schützen unsere Produkte Ihre Haut auf natürliche Weise. Verabschieden Sie sich von Hautirritationen und genießen Sie einen strahlenden Teint.",
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken: Tauchen Sie ein in die Welt der modernen Schönheit mit den neuesten Make-up-Trends. Kräftige, lebendige Farben und innovative Techniken setzen neue Maßstäbe. Von auffälligen Eyelinern bis hin zu holografischen Highlightern lassen Sie Ihrer Kreativität freien Lauf und setzen Sie jedes Mal ein Statement.",
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla: Descubre el poder de la naturaleza con nuestra línea de cuidado de la piel orgánico, diseñada especialmente para pieles sensibles. Enriquecidos con aloe vera y manzanilla, estos productos ofrecen una hidratación y protección suave. Despídete de las irritaciones y saluda a una piel radiante y saludable.",
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look.",
"针对敏感肌专门设计的天然有机护肤产品:体验由芦荟和洋甘菊提取物带来的自然呵护。我们的护肤产品特别为敏感肌设计,温和滋润,保护您的肌肤不受刺激。让您的肌肤告别不适,迎来健康光彩。",
"新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。",
"敏感肌のために特別に設計された天然有機スキンケア製品: アロエベラとカモミールのやさしい力で、自然の抱擁を感じてください。敏感肌用に特別に設計された私たちのスキンケア製品は、肌に優しく栄養を与え、保護します。肌トラブルにさようなら、輝く健康な肌にこんにちは。",
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています: 今シーズンのメイクアップトレンドは、大胆な色彩と革新的な技術に注目しています。ネオンアイライナーからホログラフィックハイライターまで、クリエイティビティを解き放ち、毎回ユニークなルックを演出しましょう。"
]
}
+34 -15
View File
@@ -5,6 +5,7 @@ import (
"os"
"strings"
"github.com/google/shlex"
"gopkg.in/yaml.v3"
)
@@ -14,6 +15,8 @@ type ModelConfig struct {
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
@@ -21,26 +24,31 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
}
type Config struct {
Models map[string]ModelConfig `yaml:"models"`
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
Models map[string]ModelConfig `yaml:"models"`
Profiles map[string][]string `yaml:"profiles"`
// map aliases to actual model IDs
aliases map[string]string
}
func (c *Config) FindConfig(modelName string) (ModelConfig, bool) {
modelConfig, found := c.Models[modelName]
if found {
return modelConfig, true
func (c *Config) RealModelName(search string) (string, bool) {
if _, found := c.Models[search]; found {
return search, true
} else if name, found := c.aliases[search]; found {
return name, found
} else {
return "", false
}
}
// Search through aliases to find the right config
for _, config := range c.Models {
for _, alias := range config.Aliases {
if alias == modelName {
return config, true
}
}
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
if realName, found := c.RealModelName(modelName); !found {
return ModelConfig{}, "", false
} else {
return c.Models[realName], realName, true
}
return ModelConfig{}, false
}
func LoadConfig(path string) (*Config, error) {
@@ -59,6 +67,14 @@ func LoadConfig(path string) (*Config, error) {
config.HealthCheckTimeout = 15
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
config.aliases[alias] = modelName
}
}
return &config, nil
}
@@ -68,7 +84,10 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
// Split the command into arguments
args := strings.Fields(cmdStr)
args, err := shlex.Split(cmdStr)
if err != nil {
return nil, err
}
// Ensure the command is not empty
if len(args) == 0 {
+66 -16
View File
@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestLoadConfig(t *testing.T) {
func TestConfig_Load(t *testing.T) {
// Create a temporary YAML file for testing
tempDir, err := os.MkdirTemp("", "test-config")
if err != nil {
@@ -17,7 +17,8 @@ func TestLoadConfig(t *testing.T) {
defer os.RemoveAll(tempDir)
tempFile := filepath.Join(tempDir, "config.yaml")
content := `models:
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
@@ -28,7 +29,17 @@ func TestLoadConfig(t *testing.T) {
- "VAR1=value1"
- "VAR2=value2"
checkEndpoint: "/health"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "m2"
checkEndpoint: "/"
healthCheckTimeout: 15
profiles:
test:
- model1
- model2
`
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
@@ -50,14 +61,33 @@ healthCheckTimeout: 15
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
},
"model2": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: nil,
CheckEndpoint: "/",
},
},
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
},
}
assert.Equal(t, expected, config)
realname, found := config.RealModelName("m1")
assert.True(t, found)
assert.Equal(t, "model1", realname)
}
func TestModelConfigSanitizedCommand(t *testing.T) {
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
--arg1 value1 \
@@ -69,7 +99,10 @@ func TestModelConfigSanitizedCommand(t *testing.T) {
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestFindConfig(t *testing.T) {
func TestConfig_FindConfig(t *testing.T) {
// TODO?
// make make this shared between the different tests
config := &Config{
Models: map[string]ModelConfig{
"model1": {
@@ -88,36 +121,53 @@ func TestFindConfig(t *testing.T) {
},
},
HealthCheckTimeout: 10,
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
},
}
// Test finding a model by its name
modelConfig, found := config.FindConfig("model1")
modelConfig, modelId, found := config.FindConfig("model1")
assert.True(t, found)
assert.Equal(t, "model1", modelId)
assert.Equal(t, config.Models["model1"], modelConfig)
// Test finding a model by its alias
modelConfig, found = config.FindConfig("m1")
modelConfig, modelId, found = config.FindConfig("m1")
assert.True(t, found)
assert.Equal(t, "model1", modelId)
assert.Equal(t, config.Models["model1"], modelConfig)
// Test finding a model that does not exist
modelConfig, found = config.FindConfig("model3")
modelConfig, modelId, found = config.FindConfig("model3")
assert.False(t, found)
assert.Equal(t, "", modelId)
assert.Equal(t, ModelConfig{}, modelConfig)
}
func TestSanitizeCommand(t *testing.T) {
// Test a simple command
args, err := SanitizeCommand("python model1.py")
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py"}, args)
func TestConfig_SanitizeCommand(t *testing.T) {
// Test a command with spaces and newlines
args, err = SanitizeCommand(`python model1.py \
--arg1 value1 \
--arg2 value2`)
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
--arg3 123 \
--arg4 '"string in string"'
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
+58
View File
@@ -0,0 +1,58 @@
package proxy
import (
"fmt"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"github.com/gin-gonic/gin"
)
var (
nextTestPort int = 12000
portMutex sync.Mutex
)
// Check if the binary exists
func TestMain(m *testing.M) {
binaryPath := getSimpleResponderPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
os.Exit(1)
}
gin.SetMode(gin.TestMode)
m.Run()
}
// Helper function to get the binary path
func getSimpleResponderPath() string {
goos := runtime.GOOS
goarch := runtime.GOARCH
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
portMutex.Lock()
defer portMutex.Unlock()
port := nextTestPort
nextTestPort++
return getTestSimpleResponderConfigPort(expectedMessage, port)
}
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
binaryPath := getSimpleResponderPath()
// Create a process configuration
return ModelConfig{
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
}
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

+14
View File
@@ -0,0 +1,14 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>llama-swap</title>
</head>
<body>
<h1>llama-swap</h1>
<p>
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
</p>
</body>
</html>
+145
View File
@@ -0,0 +1,145 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Logs</title>
<style>
body {
margin: 0;
height: 100vh;
display: flex;
flex-direction: column;
font-family: "Courier New", Courier, monospace;
}
#log-controls {
margin: 0.5em;
display: flex;
align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1;
}
#log-controls input:focus {
outline: none; /* Ensures no outline is shown when the input is focused */
}
#log-stream {
flex: 1;
margin: 0.5em;
padding: 1em;
background: #f4f4f4;
overflow-y: auto;
white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; /* Ensures long words wrap */
}
.regex-error {
background-color: #ff0000 !important;
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
body {
background-color: #333;
color: #fff;
}
#log-stream {
background: #444;
color: #fff;
}
#log-controls input {
background: #555;
color: #fff;
border: 1px solid #777;
}
#log-controls button {
background: #555;
color: #fff;
border: 1px solid #777;
}
}
</style>
</head>
<body>
<pre id="log-stream">Waiting for logs...</pre>
<div id="log-controls">
<input type="text" id="filter-input" placeholder="regex filter">
<button id="clear-button">clear</button>
</div>
<script>
const logStream = document.getElementById('log-stream');
const filterInput = document.getElementById('filter-input');
var logData = "";
let regexFilter = null;
function setupEventSource() {
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
eventSource.onmessage = function(event) {
logData += event.data;
render()
};
eventSource.onerror = function(err) {
logData = "EventSource failed: " + err.message;
};
} else {
logData = "SSE Not supported by this browser."
}
}
// poor-ai's react ¯\_(ツ)_/¯
function render() {
if (regexFilter) {
const lines = logData.split('\n');
const filteredLines = lines.filter(line => {
return regexFilter === null || regexFilter.test(line);
});
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
}
logStream.scrollTop = logStream.scrollHeight;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try {
regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
regexFilter = null;
filterInput.classList.add('regex-error');
return
}
} else {
regexFilter = null;
}
render();
}
filterInput.addEventListener('input', updateFilter);
document.getElementById('clear-button').addEventListener('click', () => {
filterInput.value = "";
regexFilter = null;
render();
});
setupEventSource();
updateFilter();
</script>
</body>
</html>
+10
View File
@@ -0,0 +1,10 @@
package proxy
import "embed"
//go:embed html
var htmlFiles embed.FS
func getHTMLFile(path string) ([]byte, error) {
return htmlFiles.ReadFile("html/" + path)
}
+96
View File
@@ -0,0 +1,96 @@
package proxy
import (
"container/ring"
"io"
"os"
"sync"
)
type LogMonitor struct {
clients map[chan []byte]bool
mu sync.RWMutex
buffer *ring.Ring
bufferMu sync.RWMutex
// typically this can be os.Stdout
stdout io.Writer
}
func NewLogMonitor() *LogMonitor {
return NewLogMonitorWriter(os.Stdout)
}
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{
clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
}
}
func (w *LogMonitor) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
n, err = w.stdout.Write(p)
if err != nil {
return n, err
}
w.bufferMu.Lock()
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.buffer.Value = bufferCopy
w.buffer = w.buffer.Next()
w.bufferMu.Unlock()
w.broadcast(bufferCopy)
return n, nil
}
func (w *LogMonitor) GetHistory() []byte {
w.bufferMu.RLock()
defer w.bufferMu.RUnlock()
var history []byte
w.buffer.Do(func(p any) {
if p != nil {
if content, ok := p.([]byte); ok {
history = append(history, content...)
}
}
})
return history
}
func (w *LogMonitor) Subscribe() chan []byte {
w.mu.Lock()
defer w.mu.Unlock()
ch := make(chan []byte, 100)
w.clients[ch] = true
return ch
}
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.clients, ch)
close(ch)
}
func (w *LogMonitor) broadcast(msg []byte) {
w.mu.RLock()
defer w.mu.RUnlock()
for client := range w.clients {
select {
case client <- msg:
default:
// If client buffer is full, skip
}
}
}
+95
View File
@@ -0,0 +1,95 @@
package proxy
import (
"bytes"
"io"
"sync"
"testing"
)
func TestLogMonitor(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
// Test subscription
client1 := logMonitor.Subscribe()
client2 := logMonitor.Subscribe()
defer logMonitor.Unsubscribe(client1)
defer logMonitor.Unsubscribe(client2)
client1Messages := make([]byte, 0)
client2Messages := make([]byte, 0)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case data := <-client1:
client1Messages = append(client1Messages, data...)
case data := <-client2:
client2Messages = append(client2Messages, data...)
default:
return
}
}
}()
logMonitor.Write([]byte("1"))
logMonitor.Write([]byte("2"))
logMonitor.Write([]byte("3"))
// Wait for the goroutine to finish
wg.Wait()
// Check the buffer
expectedHistory := "123"
history := string(logMonitor.GetHistory())
if history != expectedHistory {
t.Errorf("Expected history: %s, got: %s", expectedHistory, history)
}
c1Data := string(client1Messages)
if c1Data != expectedHistory {
t.Errorf("Client1 expected %s, got: %s", expectedHistory, c1Data)
}
c2Data := string(client2Messages)
if c2Data != expectedHistory {
t.Errorf("Client2 expected %s, got: %s", expectedHistory, c2Data)
}
}
func TestWrite_ImmutableBuffer(t *testing.T) {
// Create a new LogMonitor instance
lm := NewLogMonitorWriter(io.Discard)
// Prepare a message to write
msg := []byte("Hello, World!")
lenmsg := len(msg)
// Write the message to the LogMonitor
n, err := lm.Write(msg)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != lenmsg {
t.Errorf("Expected %d bytes written but got %d", lenmsg, n)
}
// Change the original message
msg[0] = 'B' // This should not affect the buffer
// Get the history from the LogMonitor
history := lm.GetHistory()
// Check that the history contains the original message, not the modified one
expected := []byte("Hello, World!")
if !bytes.Equal(history, expected) {
t.Errorf("Expected history to be %q, got %q", expected, history)
}
}
-227
View File
@@ -1,227 +0,0 @@
package proxy
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)
type ProxyManager struct {
sync.Mutex
config *Config
currentCmd *exec.Cmd
currentConfig ModelConfig
}
func New(config *Config) *ProxyManager {
return &ProxyManager{config: config}
}
func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) {
// https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints
if r.URL.Path == "/v1/chat/completions" {
// extracts the `model` from json body
pm.proxyChatRequest(w, r)
} else {
pm.proxyRequest(w, r)
}
}
func (pm *ProxyManager) swapModel(requestedModel string) error {
pm.Lock()
defer pm.Unlock()
// find the model configuration matching requestedModel
modelConfig, found := pm.config.FindConfig(requestedModel)
if !found {
return fmt.Errorf("could not find configuration for %s", requestedModel)
}
// no need to swap llama.cpp instances
if pm.currentConfig.Cmd == modelConfig.Cmd {
return nil
}
// kill the current running one to swap it
if pm.currentCmd != nil {
pm.currentCmd.Process.Signal(syscall.SIGTERM)
// wait for it to end
pm.currentCmd.Process.Wait()
}
pm.currentConfig = modelConfig
args, err := modelConfig.SanitizedCommand()
if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err)
}
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = modelConfig.Env
err = cmd.Start()
if err != nil {
return err
}
pm.currentCmd = cmd
if err := pm.checkHealthEndpoint(); err != nil {
return err
}
return nil
}
func (pm *ProxyManager) checkHealthEndpoint() error {
if pm.currentConfig.Proxy == "" {
return fmt.Errorf("no upstream available to check /health")
}
checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint)
if checkEndpoint == "none" {
return nil
}
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := pm.currentConfig.Proxy
maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout)
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
}
client := &http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(req.Context(), 250*time.Millisecond)
defer cancel()
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
if strings.Contains(err.Error(), "connection refused") {
// if TCP dial can't connect any HTTP response after 5 seconds
// exit quickly.
if time.Since(startTime) > 5*time.Second {
return fmt.Errorf("health check endpoint took more than 5 seconds to respond")
}
}
if time.Since(startTime) >= maxDuration {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
if time.Since(startTime) >= maxDuration {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
}
}
func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) {
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
model, ok := requestBody["model"].(string)
if !ok {
http.Error(w, "Missing or invalid 'model' key", http.StatusBadRequest)
return
}
if err := pm.swapModel(model); err != nil {
http.Error(w, fmt.Sprintf("unable to swap to model: %s", err.Error()), http.StatusNotFound)
return
}
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
pm.proxyRequest(w, r)
}
func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) {
if pm.currentConfig.Proxy == "" {
http.Error(w, "No upstream proxy", http.StatusInternalServerError)
return
}
proxyTo := pm.currentConfig.Proxy
client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
// faster than io.Copy when streaming
buf := make([]byte, 32*1024)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
http.Error(w, writeErr.Error(), http.StatusInternalServerError)
return
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
}
}
+337
View File
@@ -0,0 +1,337 @@
package proxy
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateReady ProcessState = ProcessState("ready")
StateFailed ProcessState = ProcessState("failed")
)
type Process struct {
sync.Mutex
ID string
config ModelConfig
cmd *exec.Cmd
logMonitor *LogMonitor
healthCheckTimeout int
lastRequestHandled time.Time
stateMutex sync.RWMutex
state ProcessState
inFlightRequests sync.WaitGroup
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
return &Process{
ID: ID,
config: config,
cmd: nil,
logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
state: StateStopped,
}
}
// start the process and returns when it is ready
func (p *Process) start() error {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state == StateReady {
return nil
}
if p.state == StateFailed {
return fmt.Errorf("process is in a failed state and can not be restarted")
}
args, err := p.config.SanitizedCommand()
if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err)
}
p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.logMonitor
p.cmd.Env = p.config.Env
err = p.cmd.Start()
if err != nil {
return err
}
// One of three things can happen at this stage:
// 1. The command exits unexpectedly
// 2. The health check fails
// 3. The health check passes
//
// only in the third case will the process be considered Ready to accept
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
defer cancelHealthCheck(nil) // clean up
cmdWaitChan := make(chan error, 1)
healthCheckChan := make(chan error, 1)
go func() {
// possible cmd exits early
cmdWaitChan <- p.cmd.Wait()
}()
go func() {
<-time.After(250 * time.Millisecond) // give process a bit of time to start
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
}()
select {
case err := <-cmdWaitChan:
p.state = StateFailed
if err != nil {
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
} else {
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
}
cancelHealthCheck(err)
return err
case err := <-healthCheckChan:
if err != nil {
p.state = StateFailed
return err
}
}
if p.config.UnloadAfter > 0 {
// start a goroutine to check every second if
// the process should be stopped
go func() {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) {
if p.state != StateReady {
return
}
// wait for all inflight requests to complete and ticker
p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.Stop()
return
}
}
}()
}
p.state = StateReady
return nil
}
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != StateReady {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() called but Process State is not READY\n")
return
}
if p.cmd == nil || p.cmd.Process == nil {
// this situation should never happen... but if it does just update the state
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.\n")
p.state = StateStopped
return
}
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
p.cmd.Process.Signal(syscall.SIGTERM)
select {
case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "XXX Process for %s timed out waiting to stop, sending SIGKILL to PID: %d\n", p.ID, p.cmd.Process.Pid)
p.cmd.Process.Kill()
p.cmd.Wait()
case err := <-sigtermNormal:
if err != nil {
if err.Error() != "wait: no child processes" {
// possible that simple-responder for testing is just not
// existing right, so suppress those errors.
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
}
}
}
p.state = StateStopped
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
if p.config.Proxy == "" {
return fmt.Errorf("no upstream available to check /health")
}
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
if checkEndpoint == "none" {
return nil
}
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := p.config.Proxy
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
}
client := &http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
defer cancel()
req = req.WithContext(ctx)
resp, err := client.Do(req)
ttl := (maxDuration - time.Since(startTime)).Seconds()
if err != nil {
// check if the context was cancelled
select {
case <-ctx.Done():
err := context.Cause(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
return err
}
default:
}
// wait a bit longer for TCP connection issues
if strings.Contains(err.Error(), "connection refused") {
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
time.Sleep(5 * time.Second)
} else {
time.Sleep(time.Second)
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
}
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.inFlightRequests.Done()
}()
if p.CurrentState() != StateReady {
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusInternalServerError)
return
}
}
proxyTo := p.config.Proxy
client := &http.Client{}
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header.Clone()
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
// faster than io.Copy when streaming
buf := make([]byte, 32*1024)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
return
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
}
}
+192
View File
@@ -0,0 +1,192 @@
package proxy
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
// Create a process
process := NewProcess("test-process", 5, config, logMonitor)
defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// process is automatically started
assert.Equal(t, StateStopped, process.CurrentState())
process.ProxyRequest(w, req)
assert.Equal(t, StateReady, process.CurrentState())
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expectedMessage)
// Stop the process
process.Stop()
req = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
// Proxy the request
process.ProxyRequest(w, req)
// should have automatically started the process again
if w.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
}
}
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
config := ModelConfig{
Cmd: "nonexistent-command",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("broken", 1, config, NewLogMonitor())
defer process.Stop()
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, w.Body.String(), "unable to start process")
}
func TestProcess_UnloadAfterTTL(t *testing.T) {
if testing.Short() {
t.Skip("skipping long auto unload TTL test")
}
expectedMessage := "I_sense_imminent_danger"
config := getTestSimpleResponderConfig(expectedMessage)
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop()
// this should take 4 seconds
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
req2 := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
process.ProxyRequest(w, req1)
t.Log("sending slow first request (4 seconds)")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "1234")
assert.Equal(t, StateReady, process.CurrentState())
// ensure the TTL timeout does not race slow requests (see issue #25)
t.Log("sending second request (1 second)")
time.Sleep(time.Second)
w = httptest.NewRecorder()
process.ProxyRequest(w, req2)
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expectedMessage)
assert.Equal(t, StateReady, process.CurrentState())
// wait 5 seconds
t.Log("sleep 5 seconds and check if unloaded")
time.Sleep(5 * time.Second)
assert.Equal(t, StateStopped, process.CurrentState())
}
func TestProcess_LowTTLValue(t *testing.T) {
if true { // change this code to run this ...
t.Skip("skipping test, edit process_test.go to run it ")
}
config := getTestSimpleResponderConfig("fast_ttl")
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop()
for i := 0; i < 100; i++ {
t.Logf("Waiting before sending request %d", i)
time.Sleep(1500 * time.Millisecond)
expected := fmt.Sprintf("echo=test_%d", i)
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expected)
}
}
// issue #19
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop()
results := map[string]string{
"12345": "",
"abcde": "",
"fghij": "",
}
var wg sync.WaitGroup
var mu sync.Mutex
for key := range results {
wg.Add(1)
go func(key string) {
defer wg.Done()
// send a request that should take 5 * 200ms (1 second) to complete
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
}
mu.Lock()
results[key] = w.Body.String()
mu.Unlock()
}(key)
}
// stop the requests in the middle
go func() {
<-time.After(500 * time.Millisecond)
process.Stop()
}()
wg.Wait()
for key, result := range results {
assert.Equal(t, key, result)
}
}
+357
View File
@@ -0,0 +1,357 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
const (
PROFILE_SPLIT_CHAR = ":"
)
type ProxyManager struct {
sync.Mutex
config *Config
currentProcesses map[string]*Process
logMonitor *LogMonitor
ginEngine *gin.Engine
}
func New(config *Config) *ProxyManager {
pm := &ProxyManager{
config: config,
currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(),
ginEngine: gin.New(),
}
if config.LogRequests {
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
}
// see: https://github.com/mostlygeek/llama-swap/issues/42
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.AbortWithStatus(204)
return
}
c.Next()
})
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
// Support embeddings
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
// in proxymanager_loghandlers.go
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/", func(c *gin.Context) {
// Set the Content-Type header to text/html
c.Header("Content-Type", "text/html")
// Write the embedded HTML content to the response
htmlData, err := getHTMLFile("index.html")
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
_, err = c.Writer.Write(htmlData)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
return
}
})
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
if data, err := getHTMLFile("favicon.ico"); err == nil {
c.Data(http.StatusOK, "image/x-icon", data)
} else {
c.String(http.StatusInternalServerError, err.Error())
}
})
// Disable console color for testing
gin.DisableConsoleColor()
return pm
}
func (pm *ProxyManager) Run(addr ...string) error {
return pm.ginEngine.Run(addr...)
}
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r)
}
func (pm *ProxyManager) StopProcesses() {
pm.Lock()
defer pm.Unlock()
pm.stopProcesses()
}
// for internal usage
func (pm *ProxyManager) stopProcesses() {
if len(pm.currentProcesses) == 0 {
return
}
for _, process := range pm.currentProcesses {
process.Stop()
}
pm.currentProcesses = make(map[string]*Process)
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{}
for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
data = append(data, map[string]interface{}{
"id": id,
"object": "model",
"created": time.Now().Unix(),
"owned_by": "llama-swap",
})
}
// Set the Content-Type header to application/json
c.Header("Content-Type", "application/json")
if origin := c.Request.Header.Get("Origin"); origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
}
// Encode the data as JSON and write it to the response writer
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
return
}
}
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock()
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// check if model is part of the profile
if profileName != "" {
found := false
for _, item := range pm.config.Profiles[profileName] {
if item == realModelName {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
}
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
// stop all running models
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
requestedModel := c.Param("model_id")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return
}
if process, err := pm.swapModel(requestedModel); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
} else {
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
process.ProxyRequest(c.Writer, c.Request)
}
}
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
var html strings.Builder
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
// Extract keys and sort them
var modelIDs []string
for modelID, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
modelIDs = append(modelIDs, modelID)
}
sort.Strings(modelIDs)
// Iterate over sorted keys
for _, modelID := range modelIDs {
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
}
html.WriteString("</ul></body></html>")
c.Header("Content-Type", "text/html")
c.String(http.StatusOK, html.String())
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
return
}
var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
return
}
model, ok := requestBody["model"].(string)
if !ok {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
if process, err := pm.swapModel(model); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
}
}
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
acceptHeader := c.GetHeader("Accept")
if strings.Contains(acceptHeader, "application/json") {
c.JSON(statusCode, gin.H{"error": message})
} else {
c.String(statusCode, message)
}
}
func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName
}
+113
View File
@@ -0,0 +1,113 @@
package proxy
import (
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html
c.Header("Content-Type", "text/html")
// Write the embedded HTML content to the response
logsHTML, err := getHTMLFile("logs.html")
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
_, err = c.Writer.Write(logsHTML)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
return
}
} else {
c.Header("Content-Type", "text/plain")
history := pm.logMonitor.GetHistory()
_, err := c.Writer.Write(history)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
}
}
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
return
}
_, skipHistory := c.GetQuery("no-history")
// Send history first if not skipped
if !skipHistory {
history := pm.logMonitor.GetHistory()
if len(history) != 0 {
c.Writer.Write(history)
flusher.Flush()
}
}
// Stream new logs
for {
select {
case msg := <-ch:
_, err := c.Writer.Write(msg)
if err != nil {
// just break the loop if we can't write for some reason
return
}
flusher.Flush()
case <-notify:
return
}
}
}
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
notify := c.Request.Context().Done()
// Send history first if not skipped
_, skipHistory := c.GetQuery("no-history")
if !skipHistory {
history := pm.logMonitor.GetHistory()
if len(history) != 0 {
c.SSEvent("message", string(history))
c.Writer.Flush()
}
}
// Stream new logs
for {
select {
case msg := <-ch:
c.SSEvent("message", string(msg))
c.Writer.Flush()
case <-notify:
return
}
}
}
+256
View File
@@ -0,0 +1,256 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
}
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
}
func TestProxyManager_SwapMultiProcess(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1, model2},
},
}
proxy := New(config)
defer proxy.StopProcesses()
for modelID, requestedModel := range map[string]string{
"model1": profileModel1,
"model2": profileModel2,
} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelID)
}
// make sure there's two loaded models
assert.Len(t, proxy.currentProcesses, 2)
_, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
}
// When a request for a different model comes in ProxyManager should wait until
// the first request is complete before swapping. Both requests should complete
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
results := map[string]string{}
var wg sync.WaitGroup
var mu sync.Mutex
for key := range config.Models {
wg.Add(1)
go func(key string) {
defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
}
mu.Lock()
results[key] = w.Body.String()
mu.Unlock()
}(key)
<-time.After(time.Millisecond)
}
wg.Wait()
assert.Len(t, results, len(config.Models))
for key, result := range results {
assert.Equal(t, key, result)
}
}
func TestProxyManager_ListModelsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
proxy := New(config)
// Create a test request
req := httptest.NewRequest("GET", "/v1/models", nil)
req.Header.Add("Origin", "i-am-the-origin")
w := httptest.NewRecorder()
// Call the listModelsHandler
proxy.HandlerFunc(w, req)
// Check the response status code
assert.Equal(t, http.StatusOK, w.Code)
// Check for Access-Control-Allow-Origin
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
// Parse the JSON response
var response struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err)
}
// Check the number of models returned
assert.Len(t, response.Data, 3)
// Check the details of each model
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
"model3": {},
}
for _, model := range response.Data {
modelID, ok := model["id"].(string)
assert.True(t, ok, "model ID should be a string")
_, exists := expectedModels[modelID]
assert.True(t, exists, "unexpected model ID: %s", modelID)
delete(expectedModels, modelID)
object, ok := model["object"].(string)
assert.True(t, ok, "object should be a string")
assert.Equal(t, "model", object)
created, ok := model["created"].(float64)
assert.True(t, ok, "created should be a number")
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
ownedBy, ok := model["owned_by"].(string)
assert.True(t, ok, "owned_by should be a string")
assert.Equal(t, "llama-swap", ownedBy)
}
// Ensure all expected models were returned
assert.Empty(t, expectedModels, "not all expected models were returned")
}
func TestProxyManager_ProfileNonMember(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileMemberName := ProcessKeyName("test", model1)
profileNonMemberName := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1},
},
}
proxy := New(config)
defer proxy.StopProcesses()
// actual member of profile
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
// actual model, but non-member will 404
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
}