Compare commits
35 Commits
| SHA1 |
|---|
| 5fbd53c616 |
| 97dae50dc4 |
| cb978f760f |
| 387f0ef6c4 |
| 18c134624d |
| da2326bdc7 |
| da46545630 |
| 04b4760e7e |
| 9fc5d5b5eb |
| cf82b3c633 |
| e363f8f498 |
| c9629cf3a2 |
| 50426935a4 |
| 2fceb78e8d |
| 9a81c53664 |
| 716d37de82 |
| 73ad85ea69 |
| 533162ce6a |
| ba39ed4c18 |
| 21f54f96c2 |
| 7eec51f3f2 |
| 5021e0f299 |
| c9233d2c9a |
| a33ac6f8fb |
| 401aa88949 |
| e9e88fd229 |
| c3b4bb1684 |
| e5c909ddf7 |
| 36a31f450f |
| a8e5ee13b9 |
| 5944a86e86 |
| 63d4a7d0eb |
| f45469f7ff |
| 34f9fd7340 |
| 8448efa7fc |
@@ -10,6 +10,9 @@ clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
test:
|
||||
go test -short -v ./proxy
|
||||
|
||||
test-all:
|
||||
go test -v ./proxy
|
||||
|
||||
# Build OSX binary
|
||||
@@ -22,10 +25,11 @@ linux:
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
|
||||
# for testing things
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
go build -o $(BUILD_DIR)/simple-responder misc/simple-responder/simple-responder.go
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||
|
||||
# Ensure build directory exists
|
||||
$(BUILD_DIR):
|
||||
|
||||
@@ -2,17 +2,35 @@
|
||||
|
||||

|
||||
|
||||
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models on demand. So let's swap the server on demand instead!
|
||||
# Introduction
|
||||
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
||||
|
||||
llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically.
|
||||
Features:
|
||||
|
||||
- ✅ easy to deploy: single binary with no dependencies
|
||||
- ✅ full control over llama-server's startup settings
|
||||
- ✅ ❤️ for users who are rely on llama.cpp for LLM inference
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Single yaml configuration file
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/logs`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
|
||||
## Releases
|
||||
|
||||
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Install golang for your system
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
|
||||
## config.yaml
|
||||
|
||||
llama-swap's configuration purposefully simple.
|
||||
llama-swap's configuration is purposefully simple.
|
||||
|
||||
```yaml
|
||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||
@@ -24,21 +42,25 @@ models:
|
||||
"llama":
|
||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
|
||||
# where to reach the server started by cmd
|
||||
# where to reach the server started by cmd, make sure the ports match
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases model names to use this configuration for
|
||||
# alias names that should also use this model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# wait for this path to return an HTTP 200 before serving requests
|
||||
# defaults to /health to match llama.cpp
|
||||
#
|
||||
# use "none" to skip endpoint checking. This may cause requests to fail
|
||||
# until the server is ready
|
||||
# check this path for an HTTP 200 OK before serving requests
|
||||
# default: /health to match llama.cpp
|
||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
||||
# until the model is ready
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# automatically unload the model after this many seconds
|
||||
# ttl must be greater than 0 for the model to be unloaded
|
||||
# default: 0 = never unload model
|
||||
ttl: 60
|
||||
|
||||
"qwen":
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
@@ -48,10 +70,22 @@ models:
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# profiles make it easy to manage multi-model (and multi-GPU) configurations.
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
```
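The `profile_name:model` naming convention from the tips above can be illustrated with a small Go sketch. This is not llama-swap's actual parsing code; `splitProfileModel` is a hypothetical helper, and treating a name without a colon as a plain model name is an assumption made for the example.

```go
package main

import (
	"fmt"
	"strings"
)

// splitProfileModel is a hypothetical helper (not part of llama-swap).
// It splits a requested name such as "coding:qwen" into its profile and
// model parts. A name without a colon is treated as a plain model name.
func splitProfileModel(name string) (profile, model string) {
	if before, after, found := strings.Cut(name, ":"); found {
		return before, after
	}
	return "", name
}

func main() {
	for _, name := range []string{"coding:qwen", "llama"} {
		profile, model := splitProfileModel(name)
		fmt.Printf("request %q -> profile=%q model=%q\n", name, profile, model)
	}
}
```

With the example configuration, a request for `coding:qwen` would refer to the `qwen` model inside the `coding` profile.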
|
||||
|
||||
More [examples](examples/README.md) are available for different use cases.
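The `checkEndpoint` and `healthCheckTimeout` settings describe a readiness wait: poll a path until it returns HTTP 200, or give up after a timeout. The Go sketch below shows one way such a loop could look. It is not llama-swap's implementation; `httpOK` and `waitUntilReady` are hypothetical names, and the polling interval is an assumption.

```go
package main

import (
	"fmt"
	"net/http"
	"time"
)

// httpOK reports whether a GET to url returns HTTP 200.
// Note: http.Get uses no request timeout by default.
func httpOK(url string) bool {
	resp, err := http.Get(url)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}

// waitUntilReady calls check every interval until it returns true or the
// timeout would be exceeded. It returns true if check succeeded in time.
func waitUntilReady(check func() bool, timeout, interval time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for {
		if check() {
			return true
		}
		if time.Now().Add(interval).After(deadline) {
			return false
		}
		time.Sleep(interval)
	}
}

func main() {
	// Stand-in for an upstream server that needs three probes to become healthy.
	probes := 0
	fakeCheck := func() bool {
		probes++
		return probes >= 3
	}
	ready := waitUntilReady(fakeCheck, 2*time.Second, 10*time.Millisecond)
	fmt.Println("ready:", ready, "after", probes, "probes")

	// Against a real upstream the check might be:
	//   func() bool { return httpOK("http://127.0.0.1:8999/health") }
}
```

Passing the probe as a function keeps the loop easy to exercise without a running llama-server.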
|
||||
|
||||
## Installation
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
@@ -61,22 +95,22 @@ models:
|
||||
|
||||
## Monitoring Logs
|
||||
|
||||
The `/logs` endpoint is available to monitor what llama-swap is doing. It will send the last 10KB of logs. Useful for monitoring the output of llama-server. It also supports streaming of logs.
|
||||
Open `http://<host>/logs` in your browser for a web interface with streaming logs.
|
||||
|
||||
Usage:
|
||||
Of course, CLI access is also supported:
|
||||
|
||||
```
|
||||
# basic, sends up to the last 10KB of logs
|
||||
# sends up to the last 10KB of logs
|
||||
curl 'http://host/logs'
|
||||
|
||||
# add `stream` to stream new logs as they come in
|
||||
curl -Ns 'http://host/logs?stream'
|
||||
# streams logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
|
||||
# add `skip` to skip history (only useful if used with stream)
|
||||
curl -Ns 'http://host/logs?stream&skip'
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
# will output nothing :)
|
||||
curl -Ns 'http://host/logs?skip'
|
||||
# skips history and just streams new log entries
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
## Systemd Unit Files
|
||||
@@ -103,9 +137,3 @@ StartLimitInterval=30
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
## Building from Source
|
||||
|
||||
1. Install golang for your system
|
||||
1. run `make clean all`
|
||||
1. binaries will be built into `build/` directory
|
||||
|
||||
+20
-8
@@ -1,14 +1,14 @@
|
||||
# Seconds to wait for llama.cpp to be available to serve requests
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
healthCheckTimeout: 15
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
models/llama-server-osx
|
||||
--port 8999
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
--port 9001
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
proxy: http://127.0.0.1:9001
|
||||
|
||||
# list of model name aliases this llama.cpp instance can serve
|
||||
aliases:
|
||||
@@ -17,9 +17,12 @@ models:
|
||||
# check this path for an HTTP 200 response before treating the server as ready
|
||||
checkEndpoint: /health
|
||||
|
||||
# unload model after 5 seconds
|
||||
ttl: 5
|
||||
|
||||
"qwen":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9002
|
||||
aliases:
|
||||
- gpt-3.5-turbo
|
||||
|
||||
@@ -35,7 +38,16 @@ models:
|
||||
# until the upstream server is ready for traffic
|
||||
checkEndpoint: none
|
||||
|
||||
# don't use this, just for testing if things are broken
|
||||
# don't use these, just for testing if things are broken
|
||||
"broken":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
proxy: http://127.0.0.1:8999
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
@@ -0,0 +1,6 @@
# Example Configs and Use Cases

A collection of use cases and examples for getting the most out of llama-swap.

* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds by 20% to 40%. This example includes configurations for Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
* [Optimizing Code Generation](benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
@@ -0,0 +1,123 @@
# Optimizing Code Generation with llama-swap

Finding the best mix of settings for your hardware can be time consuming. This example demonstrates using a custom configuration file to automate testing different scenarios to find an optimal configuration.

The benchmark writes a snake game in Python, TypeScript, and Swift using the Qwen 2.5 Coder models. The experiments were done using a 3090 and a P40.

**Benchmark Scenarios**

Three scenarios are tested:

- 3090-only: just the main model on the 3090
- 3090-with-draft: the main and draft models on the 3090
- 3090-P40-draft: the main model on the 3090 with the draft model offloaded to the P40

**Available Devices**

Use the following command to list the available device IDs for the configuration:

```
$ /mnt/nvme/llama-server/llama-server-f3252055 --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: Tesla P40, compute capability 6.1, VMM: yes
Device 2: Tesla P40, compute capability 6.1, VMM: yes
Device 3: Tesla P40, compute capability 6.1, VMM: yes
Available devices:
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 406 MiB free)
CUDA1: Tesla P40 (24438 MiB, 22942 MiB free)
CUDA2: Tesla P40 (24438 MiB, 24144 MiB free)
CUDA3: Tesla P40 (24438 MiB, 24144 MiB free)
```

**Configuration**

The configuration file, `benchmark-config.yaml`, defines the three scenarios:

```yaml
models:
  "3090-only":
    proxy: "http://127.0.0.1:9503"
    cmd: >
      /mnt/nvme/llama-server/llama-server-f3252055
      --host 127.0.0.1 --port 9503
      --flash-attn
      --slots

      --model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
      -ngl 99
      --device CUDA0

      --ctx-size 32768
      --cache-type-k q8_0 --cache-type-v q8_0

  "3090-with-draft":
    proxy: "http://127.0.0.1:9503"
    # --ctx-size 28500 max that can fit on 3090 after draft model
    cmd: >
      /mnt/nvme/llama-server/llama-server-f3252055
      --host 127.0.0.1 --port 9503
      --flash-attn
      --slots

      --model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
      -ngl 99
      --device CUDA0

      --model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
      -ngld 99
      --draft-max 16
      --draft-min 4
      --draft-p-min 0.4
      --device-draft CUDA0

      --ctx-size 28500
      --cache-type-k q8_0 --cache-type-v q8_0

  "3090-P40-draft":
    proxy: "http://127.0.0.1:9503"
    cmd: >
      /mnt/nvme/llama-server/llama-server-f3252055
      --host 127.0.0.1 --port 9503
      --flash-attn --metrics
      --slots
      --model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
      -ngl 99
      --device CUDA0

      --model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
      -ngld 99
      --draft-max 16
      --draft-min 4
      --draft-p-min 0.4
      --device-draft CUDA1

      --ctx-size 32768
      --cache-type-k q8_0 --cache-type-v q8_0
```

> Note: in the `3090-with-draft` scenario the `--ctx-size` had to be reduced from 32768 to 28500 to accommodate the draft model.

**Running the Benchmark**

To run the benchmark, execute the following commands:

1. `llama-swap -config benchmark-config.yaml`
1. `./run-benchmark.sh http://localhost:8080 "3090-only" "3090-with-draft" "3090-P40-draft"`

The [benchmark script](run-benchmark.sh) generates a CSV output of the results, which can be converted to a Markdown table for readability.
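
The CSV-to-Markdown step is not part of the repository, so here is a minimal Go sketch of one way to do it. `csvToMarkdown` is a hypothetical helper that assumes simple comma-separated input without quoted fields, which matches the shape of the script's output.

```go
package main

import (
	"fmt"
	"strings"
)

// csvToMarkdown converts simple comma-separated text (no quoted fields)
// into a Markdown table, treating the first line as the header row.
func csvToMarkdown(csv string) string {
	lines := strings.Split(strings.TrimSpace(csv), "\n")
	var b strings.Builder
	for i, line := range lines {
		cells := strings.Split(strings.TrimSpace(line), ",")
		b.WriteString("| " + strings.Join(cells, " | ") + " |\n")
		if i == 0 {
			// Separator row: one "---" cell per header column.
			sep := make([]string, len(cells))
			for j := range sep {
				sep[j] = "---"
			}
			b.WriteString("| " + strings.Join(sep, " | ") + " |\n")
		}
	}
	return b.String()
}

func main() {
	// Sample input in the shape the script prints; the numbers come from
	// the results reported in this README.
	csv := `model,python,typescript,swift
3090-only,34.03 tps,34.01 tps,34.01 tps
3090-with-draft,106.65 tps,70.48 tps,57.89 tps`
	fmt.Print(csvToMarkdown(csv))
}
```

In practice the script's output could be redirected to a file and read from there instead of being embedded as a string.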

**Results (tokens/second)**

| model           | python | typescript | swift |
|-----------------|--------|------------|-------|
| 3090-only       | 34.03  | 34.01      | 34.01 |
| 3090-with-draft | 106.65 | 70.48      | 57.89 |
| 3090-P40-draft  | 81.54  | 60.35      | 46.50 |

Many different factors, like the programming language, can have a big impact on the performance gains. However, with a custom configuration file for benchmarking it is easy to test the different variations to discover what's best for your hardware.

Happy coding!
+40
@@ -0,0 +1,40 @@
#!/usr/bin/env bash

# This script generates a CSV file showing the tokens/second for generating a Snake Game in python, typescript and swift
# It was created to test the effects of speculative decoding and the various draft settings on performance.
#
# Writing code with a low temperature seems to provide fairly consistent logic.
#
# Usage: ./run-benchmark.sh <url> <model1> [model2 ...]
# Example: ./run-benchmark.sh http://localhost:8080 model1 model2

# make the curl | jq pipeline report a failure if either command fails
set -o pipefail

if [ "$#" -lt 2 ]; then
    echo "Usage: $0 <url> <model1> [model2 ...]"
    exit 1
fi

url=$1; shift

echo "model,python,typescript,swift"

for model in "$@"; do

    echo -n "$model,"

    for lang in "python" "typescript" "swift"; do
        # expects a llama.cpp after PR https://github.com/ggerganov/llama.cpp/pull/10548
        # (Dec 3rd/2024)
        time=$(curl -s --url "$url/v1/chat/completions" -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"top_k\": 1, \"timings_per_token\":true, \"model\":\"$model\"}" | jq -r .timings.predicted_per_second)

        if [ $? -ne 0 ]; then
            # report the failure on stderr so the CSV on stdout stays clean
            echo "error: request for $model ($lang) failed" >&2
            exit 1
        fi

        if [ "$lang" != "swift" ]; then
            printf "%0.2f tps," "$time"
        else
            printf "%0.2f tps\n" "$time"
        fi
    done
done
@@ -0,0 +1,124 @@
# Speculative Decoding

Speculative decoding can significantly improve the tokens per second. However, this comes at the cost of increased VRAM usage for the draft model. The examples provided are based on a server with three P40s and one 3090.

## Coding Use Case

This example uses Qwen2.5 Coder 32B with the 0.5B model as a draft. A quantization of Q8_0 was chosen for the draft model, as quantization has a greater impact on smaller models.

The models used are:

* [Bartowski Qwen2.5-Coder-32B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF)
* [Bartowski Qwen2.5-Coder-0.5B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF)

The llama-swap configuration is as follows:

```yaml
models:
  "qwen-coder-32b-q4":
    # main model on 3090, draft on P40 #1
    cmd: >
      /mnt/nvme/llama-server/llama-server-be0e35
      --host 127.0.0.1 --port 9503
      --flash-attn --metrics
      --slots
      --model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
      -ngl 99
      --ctx-size 19000
      --model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
      -ngld 99
      --draft-max 16
      --draft-min 4
      --draft-p-min 0.4
      --device CUDA0
      --device-draft CUDA1
    proxy: "http://127.0.0.1:9503"
```

In this configuration, two GPUs are used: a 3090 (CUDA0) for the main model and a P40 (CUDA1) for the draft model. Although both models can fit on the 3090, relocating the draft model to the P40 freed up space for a larger context size. Despite the P40 being about one third the speed of the 3090, the small model still improved tokens per second.

Multiple tests were run with various parameters, and the fastest result was chosen for the configuration. In all tests, the 0.5B model produced the largest improvements to tokens per second.

Baseline: 33.92 tokens/second on 3090 without a draft model.

| draft-max | draft-min | draft-p-min | python | TS | swift |
|-----------|-----------|-------------|--------|----|-------|
| 16 | 1 | 0.9 | 71.64 | 55.55 | 48.06 |
| 16 | 1 | 0.4 | 83.21 | 58.55 | 45.50 |
| 16 | 1 | 0.1 | 79.72 | 55.66 | 43.94 |
| 16 | 2 | 0.9 | 68.47 | 55.13 | 43.12 |
| 16 | 2 | 0.4 | 82.82 | 57.42 | 48.83 |
| 16 | 2 | 0.1 | 81.68 | 51.37 | 45.72 |
| 16 | 4 | 0.9 | 66.44 | 48.49 | 42.40 |
| 16 | 4 | 0.4 | _83.62_ (fastest) | _58.29_ | _50.17_ |
| 16 | 4 | 0.1 | 82.46 | 51.45 | 40.71 |
| 8 | 1 | 0.4 | 67.07 | 55.17 | 48.46 |
| 4 | 1 | 0.4 | 50.13 | 44.96 | 40.79 |
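
To make the selection step concrete, the Go sketch below picks the best row from the table and reports its speedup over the baseline. Ranking rows by the mean across the three languages is an assumption for illustration; the README only says the fastest result was chosen. The `run` type and `best` function are hypothetical names.

```go
package main

import "fmt"

// run holds one row of the parameter sweep table (values in tokens/second).
type run struct {
	draftMax, draftMin int
	draftPMin          float64
	python, ts, swift  float64
}

// mean returns the average tokens/second across the three languages.
func (r run) mean() float64 { return (r.python + r.ts + r.swift) / 3 }

// best returns the run with the highest mean. It expects a non-empty slice.
func best(runs []run) run {
	top := runs[0]
	for _, r := range runs[1:] {
		if r.mean() > top.mean() {
			top = r
		}
	}
	return top
}

func main() {
	const baseline = 33.92 // tokens/second on the 3090 without a draft model

	// Rows copied from the table above.
	runs := []run{
		{16, 1, 0.9, 71.64, 55.55, 48.06},
		{16, 1, 0.4, 83.21, 58.55, 45.50},
		{16, 1, 0.1, 79.72, 55.66, 43.94},
		{16, 2, 0.9, 68.47, 55.13, 43.12},
		{16, 2, 0.4, 82.82, 57.42, 48.83},
		{16, 2, 0.1, 81.68, 51.37, 45.72},
		{16, 4, 0.9, 66.44, 48.49, 42.40},
		{16, 4, 0.4, 83.62, 58.29, 50.17},
		{16, 4, 0.1, 82.46, 51.45, 40.71},
		{8, 1, 0.4, 67.07, 55.17, 48.46},
		{4, 1, 0.4, 50.13, 44.96, 40.79},
	}

	b := best(runs)
	fmt.Printf("best: --draft-max %d --draft-min %d --draft-p-min %.1f (mean %.2f tok/s)\n",
		b.draftMax, b.draftMin, b.draftPMin, b.mean())
	fmt.Printf("python speedup over baseline: %.2fx\n", b.python/baseline)
}
```

A different ranking rule, such as optimizing for a single language, only requires changing the comparison in `best`.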

The test script can be found in this [gist](https://gist.github.com/mostlygeek/da429769796ac8a111142e75660820f1). It is a simple curl script that prompts the model to generate a snake game in Python, TypeScript, or Swift. Evaluation metrics were pulled from llama.cpp's logs.

```bash
# assumes $model is already set to a model name from the llama-swap config
for lang in "python" "typescript" "swift"; do
    echo "Generating Snake Game in $lang using $model"
    curl -s --url http://localhost:8080/v1/chat/completions -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"temperature\": 0.1, \"model\":\"$model\"}" > /dev/null
done
```

Python consistently outperformed Swift in all tests, likely due to the 0.5B draft model being more proficient in generating Python code accepted by the larger 32B model.

## Chat

This configuration is for a regular chat use case. It produces approximately 13 tokens/second in typical use, up from ~9 tokens/second with only the 3xP40s. This is great news for P40 owners.

The models used are:

* [Bartowski Meta-Llama-3.1-70B-Instruct-GGUF](https://huggingface.co/bartowski/Meta-Llama-3.1-70B-Instruct-GGUF)
* [Bartowski Llama-3.2-3B-Instruct-GGUF](https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF)

```yaml
models:
  "llama-70B":
    cmd: >
      /mnt/nvme/llama-server/llama-server-be0e35
      --host 127.0.0.1 --port 9602
      --flash-attn --metrics
      --split-mode row
      --ctx-size 80000
      --model /mnt/nvme/models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf
      -ngl 99
      --model-draft /mnt/nvme/models/Llama-3.2-3B-Instruct-Q4_K_M.gguf
      -ngld 99
      --draft-max 16
      --draft-min 1
      --draft-p-min 0.4
      --device-draft CUDA0
      --tensor-split 0,1,1,1
```

In this configuration, Llama-3.1-70B is split across three P40s, and Llama-3.2-3B is on the 3090.

Some flags deserve further explanation:

* `--split-mode row` - increases inference speeds using multiple P40s by about 30%. This is a P40-specific feature.
* `--tensor-split 0,1,1,1` - controls how the main model is split across the GPUs. This means 0% on the 3090 and an even split across the P40s. A value of `--tensor-split 0,5,4,1` would mean 0% on the 3090 and 50%, 40%, and 10% respectively across the three P40s. However, this would exceed the available VRAM.
* `--ctx-size 80000` - the maximum context size that can fit in the remaining VRAM.
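
The proportions described for `--tensor-split` can be checked with a short Go sketch. This only mirrors the arithmetic explained above (each weight divided by the sum of all weights). It is not llama.cpp's code, and `splitShares` is a hypothetical helper.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// splitShares turns a --tensor-split style value such as "0,5,4,1" into
// per-device fractions that sum to 1.
func splitShares(spec string) ([]float64, error) {
	parts := strings.Split(spec, ",")
	weights := make([]float64, len(parts))
	total := 0.0
	for i, p := range parts {
		w, err := strconv.ParseFloat(strings.TrimSpace(p), 64)
		if err != nil {
			return nil, fmt.Errorf("invalid weight %q: %w", p, err)
		}
		if w < 0 {
			return nil, fmt.Errorf("negative weight %q", p)
		}
		weights[i] = w
		total += w
	}
	if total == 0 {
		return nil, fmt.Errorf("all weights are zero in %q", spec)
	}
	for i := range weights {
		weights[i] /= total
	}
	return weights, nil
}

func main() {
	for _, spec := range []string{"0,1,1,1", "0,5,4,1"} {
		shares, err := splitShares(spec)
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		for i, s := range shares {
			fmt.Printf("%s -> CUDA%d: %.0f%%\n", spec, i, s*100)
		}
	}
}
```

Because only the ratios matter, `0,1,1,1` and `0,2,2,2` describe the same split.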

## What is CUDA0, CUDA1, CUDA2, CUDA3?

These are the device IDs used by llama.cpp.

```bash
$ ./llama-server --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: Tesla P40, compute capability 6.1, VMM: yes
Device 2: Tesla P40, compute capability 6.1, VMM: yes
Device 3: Tesla P40, compute capability 6.1, VMM: yes
Available devices:
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 23892 MiB free)
CUDA1: Tesla P40 (24438 MiB, 24290 MiB free)
CUDA2: Tesla P40 (24438 MiB, 24290 MiB free)
CUDA3: Tesla P40 (24438 MiB, 24290 MiB free)
```
@@ -8,6 +8,33 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.23.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,10 +1,85 @@
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
+9
-5
@@ -3,9 +3,9 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
@@ -22,12 +22,16 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
http.HandleFunc("/", proxyManager.HandleFunc)
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := http.ListenAndServe(*listenStr, nil); err != nil {
|
||||
fmt.Printf("Error starting server: %v\n", err)
|
||||
if err := proxyManager.Run(*listenStr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,47 +3,137 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
|
||||
silent := flag.Bool("silent", false, "disable all logging")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set the header to text/plain
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
// Create a new Gin router
|
||||
r := gin.New()
|
||||
|
||||
fmt.Fprintln(w, *responseMessage)
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
if echo == "" {
|
||||
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
}
|
||||
|
||||
// Parse the duration
|
||||
if delay == "" {
|
||||
delay = "100ms"
|
||||
}
|
||||
|
||||
t, err := time.ParseDuration(delay)
|
||||
if err != nil {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/plain")
|
||||
for _, char := range echo {
|
||||
c.Writer.Write([]byte(string(char)))
|
||||
c.Writer.Flush()
|
||||
|
||||
// wait
|
||||
<-time.After(t)
|
||||
}
|
||||
})
|
||||
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/env", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
|
||||
// Get environment variables
|
||||
envVars := os.Environ()
|
||||
|
||||
// Write each environment variable to the response
|
||||
for _, envVar := range envVars {
|
||||
fmt.Fprintln(w, envVar)
|
||||
c.String(200, envVar)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up the /health endpoint handler function
|
||||
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
response := `{"status": "ok"}`
|
||||
w.Write([]byte(response))
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
address := ":" + *port // Address with the specified port
|
||||
fmt.Printf("Server is listening on port %s\n", *port)
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||
})
|
||||
|
||||
// Start the server and log any error if it occurs
|
||||
if err := http.ListenAndServe(address, nil); err != nil {
|
||||
fmt.Printf("Error starting server: %s\n", err)
|
||||
address := "127.0.0.1:" + *port // Address with the specified port
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: address,
|
||||
Handler: r.Handler(),
|
||||
}
|
||||
|
||||
// Disable logging if the --silent flag is set
|
||||
if *silent {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("simple-responder listening on %s\n", address)
|
||||
// service connections
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("simple-responder err: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown the server with
|
||||
// a timeout of 5 seconds.
|
||||
quit := make(chan os.Signal, 1)
|
||||
// kill (no param) sends syscall.SIGTERM by default
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL, which can't be caught, so there is no need to add it
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("simple-responder shutting down")
|
||||
}
|
||||
|
||||
+32
-15
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/google/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -14,6 +15,7 @@ type ModelConfig struct {
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
@@ -21,26 +23,30 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, bool) {
|
||||
modelConfig, found := c.Models[modelName]
|
||||
if found {
|
||||
return modelConfig, true
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// Search through aliases to find the right config
|
||||
for _, config := range c.Models {
|
||||
for _, alias := range config.Aliases {
|
||||
if alias == modelName {
|
||||
return config, true
|
||||
}
|
||||
}
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
@@ -59,6 +65,14 @@ func LoadConfig(path string) (*Config, error) {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
@@ -68,7 +82,10 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||
|
||||
// Split the command into arguments
|
||||
args := strings.Fields(cmdStr)
|
||||
args, err := shlex.Split(cmdStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
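The alias handling in this diff replaces a nested loop over every model's aliases with a map built once at load time. A minimal, self-contained sketch of that idea follows. The `modelConfig` and `registry` types and their methods are hypothetical stand-ins for `ModelConfig` and `Config`, not the project's actual code.

```go
package main

import "fmt"

// modelConfig is a trimmed-down, hypothetical stand-in for ModelConfig.
type modelConfig struct {
	Cmd     string
	Aliases []string
}

// registry is a hypothetical stand-in for Config.
type registry struct {
	models  map[string]modelConfig
	aliases map[string]string // alias -> real model ID
}

// newRegistry builds the alias index once, so lookups never scan all models.
func newRegistry(models map[string]modelConfig) *registry {
	r := &registry{models: models, aliases: make(map[string]string)}
	for name, mc := range models {
		for _, alias := range mc.Aliases {
			r.aliases[alias] = name
		}
	}
	return r
}

// realModelName resolves a model ID or an alias to the real model ID.
func (r *registry) realModelName(search string) (string, bool) {
	if _, ok := r.models[search]; ok {
		return search, true
	}
	name, ok := r.aliases[search]
	return name, ok
}

func main() {
	r := newRegistry(map[string]modelConfig{
		"model1": {Cmd: "run-model1", Aliases: []string{"m1", "model-one"}},
		"model2": {Cmd: "run-model2", Aliases: []string{"m2"}},
	})
	for _, q := range []string{"model1", "m1", "m2", "model3"} {
		name, ok := r.realModelName(q)
		fmt.Printf("%q -> %q (found=%t)\n", q, name, ok)
	}
}
```

One consequence of this design: if two models declare the same alias, the last one written to the map wins, and Go's map iteration order makes that choice nondeterministic. This sketch does not guard against that.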
|
||||
|
||||
+66
-16
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
@@ -17,7 +17,8 @@ func TestLoadConfig(t *testing.T) {
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `models:
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
@@ -28,7 +29,17 @@ func TestLoadConfig(t *testing.T) {
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
@@ -50,14 +61,33 @@ healthCheckTimeout: 15
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
|
||||
func TestModelConfigSanitizedCommand(t *testing.T) {
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
@@ -69,7 +99,10 @@ func TestModelConfigSanitizedCommand(t *testing.T) {
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestFindConfig(t *testing.T) {
|
||||
func TestConfig_FindConfig(t *testing.T) {
|
||||
|
||||
// TODO?
|
||||
// maybe make this shared between the different tests
|
||||
config := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
@@ -88,36 +121,53 @@ func TestFindConfig(t *testing.T) {
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 10,
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
// Test finding a model by its name
|
||||
modelConfig, found := config.FindConfig("model1")
|
||||
modelConfig, modelId, found := config.FindConfig("model1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model by its alias
|
||||
modelConfig, found = config.FindConfig("m1")
|
||||
modelConfig, modelId, found = config.FindConfig("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model that does not exist
|
||||
modelConfig, found = config.FindConfig("model3")
|
||||
modelConfig, modelId, found = config.FindConfig("model3")
|
||||
assert.False(t, found)
|
||||
assert.Equal(t, "", modelId)
|
||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||
}
|
||||
|
||||
func TestSanitizeCommand(t *testing.T) {
|
||||
// Test a simple command
|
||||
args, err := SanitizeCommand("python model1.py")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py"}, args)
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
|
||||
// Test a command with spaces and newlines
|
||||
args, err = SanitizeCommand(`python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`)
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
--arg3 123 \
|
||||
--arg4 '"string in string"'
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
func TestMain(m *testing.M) {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
|
||||
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
// Helper function to get the binary path
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
|
||||
// Create a process configuration
|
||||
return ModelConfig{
|
||||
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Logs</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: "Courier New", Courier, monospace;
|
||||
}
|
||||
#log-stream {
|
||||
flex: 1;
|
||||
margin: 1em;
|
||||
padding: 10px;
|
||||
background: #f4f4f4;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap; /* Ensures line wrapping */
|
||||
word-wrap: break-word; /* Ensures long words wrap */
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<pre id="log-stream">Waiting for logs...
|
||||
</pre>
|
||||
|
||||
<script>
|
||||
// Establish an EventSource connection to the SSE endpoint
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
// Append the new log message to the <pre> element
|
||||
const logStream = document.getElementById('log-stream');
|
||||
|
||||
logStream.textContent += event.data;
|
||||
|
||||
// Auto-scroll to the bottom
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
};
|
||||
|
||||
eventSource.onerror = function(err) {
|
||||
console.error("EventSource failed:", err);
|
||||
};
|
||||
} else {
|
||||
console.error("SSE not supported by this browser.");
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
+8
-2
@@ -30,13 +30,19 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n, err = w.stdout.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
w.bufferMu.Lock()
|
||||
w.buffer.Value = p
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.buffer.Value = bufferCopy
|
||||
w.buffer = w.buffer.Next()
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
@@ -49,7 +55,7 @@ func (w *LogMonitor) GetHistory() []byte {
|
||||
defer w.bufferMu.RUnlock()
|
||||
|
||||
var history []byte
|
||||
w.buffer.Do(func(p interface{}) {
|
||||
w.buffer.Do(func(p any) {
|
||||
if p != nil {
|
||||
if content, ok := p.([]byte); ok {
|
||||
history = append(history, content...)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -61,3 +62,34 @@ func TestLogMonitor(t *testing.T) {
|
||||
t.Errorf("Client2 expected %s, got: %s", expectedHistory, c2Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite_ImmutableBuffer(t *testing.T) {
|
||||
// Create a new LogMonitor instance
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Prepare a message to write
|
||||
msg := []byte("Hello, World!")
|
||||
lenmsg := len(msg)
|
||||
|
||||
// Write the message to the LogMonitor
|
||||
n, err := lm.Write(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
if n != lenmsg {
|
||||
t.Errorf("Expected %d bytes written but got %d", lenmsg, n)
|
||||
}
|
||||
|
||||
// Change the original message
|
||||
msg[0] = 'B' // This should not affect the buffer
|
||||
|
||||
// Get the history from the LogMonitor
|
||||
history := lm.GetHistory()
|
||||
|
||||
// Check that the history contains the original message, not the modified one
|
||||
expected := []byte("Hello, World!")
|
||||
if !bytes.Equal(history, expected) {
|
||||
t.Errorf("Expected history to be %q, got %q", expected, history)
|
||||
}
|
||||
}
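The `Write` change above copies `p` before storing it, because an `io.Writer` must not retain the caller's slice. The sketch below shows the difference, assuming the history buffer is a `container/ring` as the `Value`, `Next()` and `Do` calls in the diff suggest. The `history` type and its `copyFirst` flag are hypothetical and exist only for this comparison.

```go
package main

import (
	"container/ring"
	"fmt"
)

// history is a hypothetical, single-goroutine stand-in for LogMonitor's buffer.
type history struct {
	buf *ring.Ring
}

func newHistory(n int) *history { return &history{buf: ring.New(n)} }

// write stores p in the ring. With copyFirst=false it keeps the caller's
// slice, which is the behavior the diff removes.
func (h *history) write(p []byte, copyFirst bool) {
	if copyFirst {
		c := make([]byte, len(p))
		copy(c, p)
		p = c
	}
	h.buf.Value = p
	h.buf = h.buf.Next()
}

// bytes concatenates the stored entries, oldest first.
func (h *history) bytes() []byte {
	var out []byte
	h.buf.Do(func(v any) {
		if b, ok := v.([]byte); ok {
			out = append(out, b...)
		}
	})
	return out
}

func main() {
	msg := []byte("Hello")

	aliased := newHistory(4)
	aliased.write(msg, false)

	copied := newHistory(4)
	copied.write(msg, true)

	msg[0] = 'J' // the caller reuses its buffer after Write returns

	fmt.Printf("%s %s\n", aliased.bytes(), copied.bytes()) // prints "Jello Hello"
}
```

This matters in practice because writers such as `os/exec` reuse their internal copy buffer between calls, so stored slices that alias it could later show different bytes.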
|
||||
|
||||
@@ -1,296 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config *Config
|
||||
currentCmd *exec.Cmd
|
||||
currentConfig ModelConfig
|
||||
logMonitor *LogMonitor
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
return &ProxyManager{config: config, logMonitor: NewLogMonitor()}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints
|
||||
if r.URL.Path == "/v1/chat/completions" {
|
||||
// extracts the `model` from json body
|
||||
pm.proxyChatRequest(w, r)
|
||||
} else if r.URL.Path == "/v1/models" {
|
||||
pm.listModels(w, r)
|
||||
} else if r.URL.Path == "/logs" {
|
||||
pm.streamLogs(w, r)
|
||||
} else {
|
||||
pm.proxyRequest(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogs(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
|
||||
notify := r.Context().Done()
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
skipHistory := r.URL.Query().Has("skip")
|
||||
if !skipHistory {
|
||||
// Send history first
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
w.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if !r.URL.Query().Has("stream") {
|
||||
return
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
w.Write(msg)
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModels(w http.ResponseWriter, _ *http.Request) {
|
||||
data := []interface{}{}
|
||||
for id := range pm.config.Models {
|
||||
data = append(data, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
"created": time.Now().Unix(),
|
||||
"owned_by": "llama-swap",
|
||||
})
|
||||
}
|
||||
|
||||
// Set the Content-Type header to application/json
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
http.Error(w, "Error encoding JSON", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapModel(requestedModel string) error {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// find the model configuration matching requestedModel
|
||||
modelConfig, found := pm.config.FindConfig(requestedModel)
|
||||
if !found {
|
||||
return fmt.Errorf("could not find configuration for %s", requestedModel)
|
||||
}
|
||||
|
||||
// no need to swap llama.cpp instances
|
||||
if pm.currentConfig.Cmd == modelConfig.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// kill the current running one to swap it
|
||||
if pm.currentCmd != nil {
|
||||
pm.currentCmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
// wait for it to end
|
||||
pm.currentCmd.Process.Wait()
|
||||
}
|
||||
|
||||
pm.currentConfig = modelConfig
|
||||
|
||||
args, err := modelConfig.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
|
||||
// logMonitor only writes to stdout
|
||||
// so the upstream's stderr will go to os.Stdout
|
||||
cmd.Stdout = pm.logMonitor
|
||||
cmd.Stderr = pm.logMonitor
|
||||
|
||||
cmd.Env = modelConfig.Env
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pm.currentCmd = cmd
|
||||
|
||||
if err := pm.checkHealthEndpoint(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) checkHealthEndpoint() error {
|
||||
|
||||
if pm.currentConfig.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := pm.currentConfig.Proxy
|
||||
maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(req.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
|
||||
// if the TCP dial still gets no HTTP response after 5 seconds
|
||||
// exit quickly.
|
||||
if time.Since(startTime) > 5*time.Second {
|
||||
return fmt.Errorf("health check endpoint took more than 5 seconds to respond")
|
||||
}
|
||||
}
|
||||
|
||||
if time.Since(startTime) >= maxDuration {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
if time.Since(startTime) >= maxDuration {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
model, ok := requestBody["model"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "Missing or invalid 'model' key", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := pm.swapModel(model); err != nil {
|
||||
http.Error(w, fmt.Sprintf("unable to swap to model: %s", err.Error()), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
pm.proxyRequest(w, r)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if pm.currentConfig.Proxy == "" {
|
||||
http.Error(w, "No upstream proxy", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxyTo := pm.currentConfig.Proxy
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// faster than io.Copy when streaming
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
http.Error(w, writeErr.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,332 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
|
||||
const (
|
||||
StateStopped ProcessState = ProcessState("stopped")
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateFailed ProcessState = ProcessState("failed")
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
sync.Mutex
|
||||
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
logMonitor *LogMonitor
|
||||
healthCheckTimeout int
|
||||
|
||||
lastRequestHandled time.Time
|
||||
|
||||
stateMutex sync.RWMutex
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
logMonitor: logMonitor,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
state: StateStopped,
|
||||
}
|
||||
}
|
||||
|
||||
// start launches the process and returns once it is ready
|
||||
func (p *Process) start() error {
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state == StateReady {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.state == StateFailed {
|
||||
return fmt.Errorf("process is in a failed state and can not be restarted")
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
p.cmd = exec.Command(args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.logMonitor
|
||||
p.cmd.Stderr = p.logMonitor
|
||||
p.cmd.Env = p.config.Env
|
||||
|
||||
err = p.cmd.Start()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
// 1. The command exits unexpectedly
|
||||
// 2. The health check fails
|
||||
// 3. The health check passes
|
||||
//
|
||||
// only in the third case will the process be considered Ready to accept requests
|
||||
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
|
||||
defer cancelHealthCheck(nil) // clean up
|
||||
cmdWaitChan := make(chan error, 1)
|
||||
healthCheckChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// possible cmd exits early
|
||||
cmdWaitChan <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-cmdWaitChan:
|
||||
p.state = StateFailed
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
|
||||
} else {
|
||||
err = fmt.Errorf("command [%s] exited unexpectedly", strings.Join(p.cmd.Args, " "))
|
||||
}
|
||||
cancelHealthCheck(err)
|
||||
return err
|
||||
case err := <-healthCheckChan:
|
||||
if err != nil {
|
||||
p.state = StateFailed
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if p.config.UnloadAfter > 0 {
|
||||
// start a goroutine to check every second if
|
||||
// the process should be stopped
|
||||
go func() {
|
||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||
|
||||
for range time.Tick(time.Second) {
|
||||
// wait for all in-flight requests to complete before checking the TTL
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
p.state = StateReady
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) Stop() {
|
||||
// wait for any inflight requests before proceeding
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
// this situation should never happen... but if it does just update the state
|
||||
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
|
||||
p.state = StateStopped
|
||||
return
|
||||
}
|
||||
|
||||
// Pretty sure this stopping code needs some work for windows and
|
||||
// will be a source of pain in the future.
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- p.cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Printf("!!! process for %s timed out waiting to stop\n", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
p.cmd.Wait()
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
if err.Error() != "wait: no child processes" {
|
||||
// It is possible that simple-responder (used for testing) is just not
// exiting cleanly, so those errors are suppressed.
|
||||
fmt.Printf("!!! process for %s stopped with error > %v\n", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
p.state = StateStopped
|
||||
}
|
||||
|
||||
func (p *Process) CurrentState() ProcessState {
|
||||
p.stateMutex.RLock()
|
||||
defer p.stateMutex.RUnlock()
|
||||
return p.state
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||
|
||||
if err != nil {
|
||||
// check if the context was cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := context.Cause(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// wait a bit longer for TCP connection issues
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
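`checkHealthEndpoint` calls `defer cancel()` and `defer resp.Body.Close()` inside its retry loop, so those cleanups only run when the whole function returns. One possible alternative is to move a single attempt into its own function so cleanup runs per attempt. The sketch below illustrates that idea only. `probeOnce`, `pollHealth`, `roundTripFunc` and the fake upstream URL are hypothetical names, and a stub transport stands in for a real server.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

// roundTripFunc lets a plain function act as an http.RoundTripper (demo only).
type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }

// probeOnce performs one GET with its own timeout. The deferred calls run
// when this function returns, so nothing accumulates across retries.
func probeOnce(parent context.Context, client *http.Client, url string, timeout time.Duration) (int, error) {
	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return 0, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	return resp.StatusCode, nil
}

// pollHealth retries until it sees 200 OK, the parent context is cancelled,
// or maxWait has elapsed.
func pollHealth(parent context.Context, client *http.Client, url string, interval, maxWait time.Duration) error {
	deadline := time.Now().Add(maxWait)
	for {
		status, err := probeOnce(parent, client, url, time.Second)
		if err == nil && status == http.StatusOK {
			return nil
		}
		if parent.Err() != nil {
			return context.Cause(parent)
		}
		if time.Now().After(deadline) {
			return fmt.Errorf("health check timed out for %s", url)
		}
		time.Sleep(interval)
	}
}

// demoPollHealth polls a stub upstream that answers 503 twice and then 200.
// It returns how many attempts were made.
func demoPollHealth() (int, error) {
	calls := 0
	client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
		calls++
		status := http.StatusServiceUnavailable
		if calls >= 3 {
			status = http.StatusOK
		}
		return &http.Response{StatusCode: status, Header: make(http.Header), Body: http.NoBody, Request: r}, nil
	})}
	err := pollHealth(context.Background(), client, "http://upstream.invalid/health", 10*time.Millisecond, 2*time.Second)
	return calls, err
}
```

The sketch omits the longer back-off that the original applies on "connection refused". That could be added inside `pollHealth` without changing the per-attempt cleanup.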
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
if p.CurrentState() != StateReady {
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
http.Error(w, errstr, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// copy in chunks and flush after each write so streamed responses reach the client promptly
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
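The copy loop at the end of `ProxyRequest` forwards each chunk and flushes it immediately, which matters for token-by-token streaming. A small standalone sketch of the same technique is shown here. `copyAndFlush` and `demoCopyAndFlush` are hypothetical names, and an `httptest.ResponseRecorder` stands in for a real client connection.

```go
package main

import (
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
)

// copyAndFlush copies src to w in 32 KiB chunks and flushes after every
// write when the ResponseWriter supports http.Flusher.
func copyAndFlush(w http.ResponseWriter, src io.Reader) error {
	flusher, canFlush := w.(http.Flusher)
	buf := make([]byte, 32*1024)
	for {
		n, err := src.Read(buf)
		if n > 0 {
			if _, werr := w.Write(buf[:n]); werr != nil {
				return werr // the client most likely went away
			}
			if canFlush {
				flusher.Flush()
			}
		}
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
	}
}

// demoCopyAndFlush streams a short body into an in-memory recorder and
// reports what was written and whether a flush happened.
func demoCopyAndFlush() (string, bool, error) {
	rec := httptest.NewRecorder()
	err := copyAndFlush(rec, strings.NewReader("token-1 token-2"))
	return rec.Body.String(), rec.Flushed, err
}
```

Returning the error instead of calling `http.Error` is a deliberate choice in this sketch: once the status line and part of the body have been sent, a new status code can no longer be delivered to the client.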
@@ -0,0 +1,166 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// process is automatically started
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
|
||||
// Stop the process
|
||||
process.Stop()
|
||||
|
||||
req = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
// Proxy the request
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// should have automatically started the process again
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
config := ModelConfig{
|
||||
Cmd: "nonexistent-command",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
}
|
||||
|
||||
// test that the process unloads after the TTL
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
}
|
||||
|
||||
expectedMessage := "I_sense_imminent_danger"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
|
||||
process.ProxyRequest(w, req1)
|
||||
|
||||
t.Log("sending slow first request (4 seconds)")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "1234")
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// ensure the TTL timeout does not race slow requests (see issue #25)
|
||||
t.Log("sending second request (1 second)")
|
||||
time.Sleep(time.Second)
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req2)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// wait 5 seconds
|
||||
t.Log("sleep 5 seconds and check if unloaded")
|
||||
time.Sleep(5 * time.Second)
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
// issue #19
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
"12345": "",
|
||||
"abcde": "",
|
||||
"fghij": "",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
for key := range results {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request that should take 5 * 200ms (1 second) to complete
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
results[key] = w.Body.String()
|
||||
mu.Unlock()
|
||||
|
||||
}(key)
|
||||
}
|
||||
|
||||
// stop the requests in the middle
|
||||
go func() {
|
||||
<-time.After(500 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for key, result := range results {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,216 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
PROFILE_SPLIT_CHAR = ":"
|
||||
)
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config *Config
|
||||
currentProcesses map[string]*Process
|
||||
logMonitor *LogMonitor
|
||||
ginEngine *gin.Engine
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
currentProcesses: make(map[string]*Process),
|
||||
logMonitor: NewLogMonitor(),
|
||||
ginEngine: gin.New(),
|
||||
}
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyChatRequestHandler)
|
||||
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyChatRequestHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
|
||||
// in proxymanager_loghandlers.go
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||
|
||||
pm.ginEngine.NoRoute(pm.proxyNoRouteHandler)
|
||||
|
||||
// Disable console color for testing
|
||||
gin.DisableConsoleColor()
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Run(addr ...string) error {
|
||||
return pm.ginEngine.Run(addr...)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||
pm.ginEngine.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) StopProcesses() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
pm.stopProcesses()
|
||||
}
|
||||
|
||||
// for internal usage
|
||||
func (pm *ProxyManager) stopProcesses() {
|
||||
if len(pm.currentProcesses) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, process := range pm.currentProcesses {
|
||||
process.Stop()
|
||||
}
|
||||
|
||||
pm.currentProcesses = make(map[string]*Process)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
for id := range pm.config.Models {
|
||||
data = append(data, map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
"created": time.Now().Unix(),
|
||||
"owned_by": "llama-swap",
|
||||
})
|
||||
}
|
||||
|
||||
// Set the Content-Type header to application/json
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// Check if requestedModel has a profile prefix, e.g. "profile:model"
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
}
|
||||
|
||||
if profileName != "" {
|
||||
if _, found := pm.config.Profiles[profileName]; !found {
|
||||
return nil, fmt.Errorf("model group not found %s", profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// resolve any alias to the real model name
|
||||
realModelName, found := pm.config.RealModelName(modelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
// exit early when already running, otherwise stop everything and swap
|
||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||
|
||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// stop all running models
|
||||
pm.stopProcesses()
|
||||
|
||||
if profileName == "" {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
} else {
|
||||
for _, modelName := range pm.config.Profiles[profileName] {
|
||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// requestedProcessKey should exist due to swap
|
||||
return pm.currentProcesses[requestedProcessKey], nil
|
||||
}
|
||||
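`swapModel` accepts names of the form `profile:model` and uses the same separator to build process keys. The sketch below isolates that parsing. `splitProfileModel`, `processKey` and `profileSplitChar` are illustrative stand-ins for the logic around `PROFILE_SPLIT_CHAR` and `ProcessKeyName`, not the functions from this change.

```go
package main

import "strings"

// profileSplitChar mirrors PROFILE_SPLIT_CHAR from the diff.
const profileSplitChar = ":"

// splitProfileModel splits "profile:model" at the first separator.
// Without a separator the profile is empty and the input is the model name.
func splitProfileModel(requested string) (profile, model string) {
	if idx := strings.Index(requested, profileSplitChar); idx != -1 {
		return requested[:idx], requested[idx+1:]
	}
	return "", requested
}

// processKey builds the map key for a running process. A model without a
// profile therefore gets a key that starts with the separator.
func processKey(profile, model string) string {
	return profile + profileSplitChar + model
}
```

Because only the first separator is used, model IDs that contain `/` (such as `path1/model1` in the tests) pass through unchanged, while a model ID that itself contains `:` would be read as having a profile prefix.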
|
||||
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("could not read request body"))
|
||||
return
|
||||
}
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
||||
return
|
||||
}
|
||||
model, ok := requestBody["model"].(string)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("missing or invalid 'model' key"))
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(model); err != nil {
|
||||
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
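`proxyChatRequestHandler` has to read the JSON body to find the `model` field and then hand an unread body to the upstream. A minimal sketch of that read-then-restore step, using only the standard library, might look as follows. `peekModel` and `demoPeekModel` are hypothetical names, and the typed struct is a simplification of the `map[string]interface{}` used in the handler.

```go
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
)

// peekModel reads the request body, extracts the "model" field and then
// restores the body so a later handler can read it again.
func peekModel(r *http.Request) (string, error) {
	body, err := io.ReadAll(r.Body)
	if err != nil {
		return "", err
	}
	// Put the bytes back and record the now-known length.
	r.Body = io.NopCloser(bytes.NewReader(body))
	r.ContentLength = int64(len(body))

	var payload struct {
		Model string `json:"model"`
	}
	if err := json.Unmarshal(body, &payload); err != nil {
		return "", err
	}
	if payload.Model == "" {
		return "", errors.New("missing or invalid 'model' key")
	}
	return payload.Model, nil
}

// demoPeekModel returns the detected model and the body that is still
// readable afterwards.
func demoPeekModel(raw string) (string, string, error) {
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(raw))
	model, err := peekModel(req)
	if err != nil {
		return "", "", err
	}
	remaining, err := io.ReadAll(req.Body)
	return model, string(remaining), err
}
```

Setting `ContentLength` on the request is how this sketch records the body size. Whether that is sufficient for the chunked-encoding case mentioned in the handler's comment is not verified here.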
|
||||
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
||||
// since maps are unordered, just use the first available process if one exists
|
||||
for _, process := range pm.currentProcesses {
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
return
|
||||
}
|
||||
|
||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
||||
}
|
||||
|
||||
func ProcessKeyName(groupName, modelName string) string {
|
||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
//go:embed html/logs.html
|
||||
var logsHTML []byte
|
||||
|
||||
// make sure embed is kept there by the IDE auto-package importer
|
||||
var _ = embed.FS{}
|
||||
|
||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "text/html") {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
_, err := c.Writer.Write(logsHTML)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.logMonitor.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||
return
|
||||
}
|
||||
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
// Send history first if not skipped
|
||||
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
_, err := c.Writer.Write(msg)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
|
||||
// Send history first if not skipped
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.SSEvent("message", string(history))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
c.SSEvent("message", string(msg))
|
||||
c.Writer.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
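`streamLogsHandlerSSE` relies on gin's `SSEvent` helper to frame each log message. For readers unfamiliar with the wire format, here is a hand-rolled sketch of a Server-Sent Event written with the standard library only. `writeSSE` and `demoWriteSSE` are hypothetical, and the exact bytes gin emits may differ in whitespace.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// writeSSE writes a single Server-Sent Event. Every line of the payload is
// sent as its own "data:" field, and a blank line terminates the event.
func writeSSE(w http.ResponseWriter, event, payload string) {
	fmt.Fprintf(w, "event: %s\n", event)
	for _, line := range strings.Split(payload, "\n") {
		fmt.Fprintf(w, "data: %s\n", line)
	}
	fmt.Fprint(w, "\n")
	// Push the event to the client right away when possible.
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
}

// demoWriteSSE writes one two-line event into an in-memory recorder and
// returns the raw bytes a client would receive.
func demoWriteSSE() string {
	rec := httptest.NewRecorder()
	rec.Header().Set("Content-Type", "text/event-stream")
	writeSSE(rec, "message", "line one\nline two")
	return rec.Body.String()
}
```

Splitting on newlines matters for log history, which usually spans many lines: a raw newline inside a single `data:` field would otherwise end the field early.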
@@ -0,0 +1,212 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||
|
||||
}
|
||||
|
||||
// make sure there's only one loaded model
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
}
|
||||
|
||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileModel1 := ProcessKeyName("test", model1)
|
||||
profileModel2 := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1, model2},
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
for modelID, requestedModel := range map[string]string{
|
||||
"model1": profileModel1,
|
||||
"model2": profileModel2,
|
||||
} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelID)
|
||||
}
|
||||
|
||||
// make sure there are two loaded models
|
||||
assert.Len(t, proxy.currentProcesses, 2)
|
||||
_, exists := proxy.currentProcesses[profileModel1]
|
||||
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
||||
|
||||
_, exists = proxy.currentProcesses[profileModel2]
|
||||
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
||||
}
|
||||
|
||||
// When a request for a different model comes in, ProxyManager should wait until
// the first request is complete before swapping. Both requests should complete.
|
||||
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
results := map[string]string{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
for key := range config.Models {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
|
||||
results[key] = w.Body.String()
|
||||
mu.Unlock()
|
||||
}(key)
|
||||
|
||||
<-time.After(time.Millisecond)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Len(t, results, len(config.Models))
|
||||
|
||||
for key, result := range results {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Add("Origin", "i-am-the-origin")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the listModelsHandler
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
// Check the response status code
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check for Access-Control-Allow-Origin
|
||||
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
|
||||
|
||||
// Parse the JSON response
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Check the number of models returned
|
||||
assert.Len(t, response.Data, 3)
|
||||
|
||||
// Check the details of each model
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
}
|
||||
|
||||
for _, model := range response.Data {
|
||||
modelID, ok := model["id"].(string)
|
||||
assert.True(t, ok, "model ID should be a string")
|
||||
_, exists := expectedModels[modelID]
|
||||
assert.True(t, exists, "unexpected model ID: %s", modelID)
|
||||
delete(expectedModels, modelID)
|
||||
|
||||
object, ok := model["object"].(string)
|
||||
assert.True(t, ok, "object should be a string")
|
||||
assert.Equal(t, "model", object)
|
||||
|
||||
created, ok := model["created"].(float64)
|
||||
assert.True(t, ok, "created should be a number")
|
||||
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
|
||||
|
||||
ownedBy, ok := model["owned_by"].(string)
|
||||
assert.True(t, ok, "owned_by should be a string")
|
||||
assert.Equal(t, "llama-swap", ownedBy)
|
||||
}
|
||||
|
||||
// Ensure all expected models were returned
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||