Compare commits
8 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 97dae50dc4 | |
| | cb978f760f | |
| | 387f0ef6c4 | |
| | 18c134624d | |
| | da2326bdc7 | |
| | da46545630 | |
| | 04b4760e7e | |
| | 9fc5d5b5eb | |
@@ -2,20 +2,32 @@
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
llama-swap is a golang server that automatically swaps the llama.cpp server on demand. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
# Introduction
|
||||||
|
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
|
|
||||||
- ✅ Easy to deploy: single binary with no dependencies
|
- ✅ Easy to deploy: single binary with no dependencies
|
||||||
- ✅ Single yaml configuration file
|
- ✅ Single yaml configuration file
|
||||||
- ✅ Automatic switching between models
|
- ✅ On-demand model switching
|
||||||
- ✅ Full control over llama.cpp server settings per model
|
- ✅ Full control over server settings per model
|
||||||
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
||||||
- ✅ Multiple GPU support
|
- ✅ Multiple GPU support
|
||||||
- ✅ Run multiple models at once with `profiles`
|
- ✅ Run multiple models at once with `profiles`
|
||||||
- ✅ Remote log monitoring at `/logs`
|
- ✅ Remote log monitoring at `/logs`
|
||||||
- ✅ Automatic unloading of models from GPUs after timeout
|
- ✅ Automatic unloading of models from GPUs after timeout
|
||||||
|
|
||||||
|
## Releases
|
||||||
|
|
||||||
|
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
|
||||||
|
|
||||||
|
### Building from source
|
||||||
|
|
||||||
|
1. Install golang for your system
|
||||||
|
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||||
|
1. `make clean all`
|
||||||
|
1. Binaries will be in `build/` subdirectory
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
llama-swap's configuration is purposefully simple.
|
llama-swap's configuration is purposefully simple.
|
||||||
@@ -64,7 +76,7 @@ models:
|
|||||||
#
|
#
|
||||||
# Tips:
|
# Tips:
|
||||||
# - each model must be listening on a unique address and port
|
# - each model must be listening on a unique address and port
|
||||||
# - the model name is in this format: "profile_name/model", like "coding/qwen"
|
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||||
# - the profile will load and unload all models in the profile at the same time
|
# - the profile will load and unload all models in the profile at the same time
|
||||||
profiles:
|
profiles:
|
||||||
coding:
|
coding:
|
||||||
@@ -83,22 +95,22 @@ More [examples](examples/README.md) are available for different use cases.
|
|||||||
|
|
||||||
## Monitoring Logs
|
## Monitoring Logs
|
||||||
|
|
||||||
The `/logs` endpoint is available to monitor what llama-swap is doing. It will send the last 10KB of logs. Useful for monitoring the output of llama-server. It also supports streaming of logs.
|
Open `http://<host>/logs` in your browser for a web interface with streaming logs.
|
||||||
|
|
||||||
Usage:
|
Of course, CLI access is also supported:
|
||||||
|
|
||||||
```
|
```
|
||||||
# sends up to the last 10KB of logs
|
# sends up to the last 10KB of logs
|
||||||
curl 'http://host/logs'
|
curl 'http://host/logs'
|
||||||
|
|
||||||
# streams logs using chunk encoding
|
# streams logs
|
||||||
curl -Ns 'http://host/logs/stream'
|
curl -Ns 'http://host/logs/stream'
|
||||||
|
|
||||||
|
# stream and filter logs with linux pipes
|
||||||
|
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||||
|
|
||||||
# skips history and just streams new log entries
|
# skips history and just streams new log entries
|
||||||
curl -Ns 'http://host/logs/stream?no-history'
|
curl -Ns 'http://host/logs/stream?no-history'
|
||||||
|
|
||||||
# streams logs using Server Sent Events
|
|
||||||
curl -Ns 'http://host/logs/streamSSE'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Systemd Unit Files
|
## Systemd Unit Files
|
||||||
@@ -125,9 +137,3 @@ StartLimitInterval=30
|
|||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
```
|
```
|
||||||
|
|
||||||
## Building from Source
|
|
||||||
|
|
||||||
1. Install golang for your system
|
|
||||||
1. run `make clean all`
|
|
||||||
1. binaries will be built into `build/` directory
|
|
||||||
|
|||||||
+3
-6
@@ -1,9 +1,6 @@
|
|||||||
# Example Configurations
|
# Example Configs and Use Cases
|
||||||
|
|
||||||
Learning by example is best.
|
A collection of use cases and examples for getting the most out of llama-swap.
|
||||||
|
|
||||||
Here in the `examples/` folder are llama-swap configurations that can be used on your local LLM server.
|
|
||||||
|
|
||||||
## List
|
|
||||||
|
|
||||||
* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes configurations for Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
* [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes configurations for Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||||
|
* [Optimizing Code Generation](benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
# Optimizing Code Generation with llama-swap
|
||||||
|
|
||||||
|
Finding the best mix of settings for your hardware can be time-consuming. This example demonstrates using a custom configuration file to automate testing different scenarios and find an optimal configuration.
|
||||||
|
|
||||||
|
The benchmark writes a snake game in Python, TypeScript, and Swift using the Qwen 2.5 Coder models. The experiments were done using a 3090 and a P40.
|
||||||
|
|
||||||
|
**Benchmark Scenarios**
|
||||||
|
|
||||||
|
Three scenarios are tested:
|
||||||
|
|
||||||
|
- 3090-only: Just the main model on the 3090
|
||||||
|
- 3090-with-draft: the main and draft models on the 3090
|
||||||
|
- 3090-P40-draft: the main model on the 3090 with the draft model offloaded to the P40
|
||||||
|
|
||||||
|
**Available Devices**
|
||||||
|
|
||||||
|
Use the following command to list the available device IDs for the configuration:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ /mnt/nvme/llama-server/llama-server-f3252055 --list-devices
|
||||||
|
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
|
||||||
|
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
|
||||||
|
ggml_cuda_init: found 4 CUDA devices:
|
||||||
|
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
|
||||||
|
Device 1: Tesla P40, compute capability 6.1, VMM: yes
|
||||||
|
Device 2: Tesla P40, compute capability 6.1, VMM: yes
|
||||||
|
Device 3: Tesla P40, compute capability 6.1, VMM: yes
|
||||||
|
Available devices:
|
||||||
|
CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 406 MiB free)
|
||||||
|
CUDA1: Tesla P40 (24438 MiB, 22942 MiB free)
|
||||||
|
CUDA2: Tesla P40 (24438 MiB, 24144 MiB free)
|
||||||
|
CUDA3: Tesla P40 (24438 MiB, 24144 MiB free)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Configuration**
|
||||||
|
|
||||||
|
The configuration file, `benchmark-config.yaml`, defines the three scenarios:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
"3090-only":
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
cmd: >
|
||||||
|
/mnt/nvme/llama-server/llama-server-f3252055
|
||||||
|
--host 127.0.0.1 --port 9503
|
||||||
|
--flash-attn
|
||||||
|
--slots
|
||||||
|
|
||||||
|
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||||
|
-ngl 99
|
||||||
|
--device CUDA0
|
||||||
|
|
||||||
|
--ctx-size 32768
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
|
||||||
|
"3090-with-draft":
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
# --ctx-size 28500 max that can fit on 3090 after draft model
|
||||||
|
cmd: >
|
||||||
|
/mnt/nvme/llama-server/llama-server-f3252055
|
||||||
|
--host 127.0.0.1 --port 9503
|
||||||
|
--flash-attn
|
||||||
|
--slots
|
||||||
|
|
||||||
|
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||||
|
-ngl 99
|
||||||
|
--device CUDA0
|
||||||
|
|
||||||
|
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
|
||||||
|
-ngld 99
|
||||||
|
--draft-max 16
|
||||||
|
--draft-min 4
|
||||||
|
--draft-p-min 0.4
|
||||||
|
--device-draft CUDA0
|
||||||
|
|
||||||
|
--ctx-size 28500
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
|
||||||
|
"3090-P40-draft":
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
cmd: >
|
||||||
|
/mnt/nvme/llama-server/llama-server-f3252055
|
||||||
|
--host 127.0.0.1 --port 9503
|
||||||
|
--flash-attn --metrics
|
||||||
|
--slots
|
||||||
|
--model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||||
|
-ngl 99
|
||||||
|
--device CUDA0
|
||||||
|
|
||||||
|
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf
|
||||||
|
-ngld 99
|
||||||
|
--draft-max 16
|
||||||
|
--draft-min 4
|
||||||
|
--draft-p-min 0.4
|
||||||
|
--device-draft CUDA1
|
||||||
|
|
||||||
|
--ctx-size 32768
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
```
|
||||||
|
|
||||||
|
> Note: in the `3090-with-draft` scenario the `--ctx-size` had to be reduced from 32768 to 28500 to accommodate the draft model.
|
||||||
|
|
||||||
|
|
||||||
|
**Running the Benchmark**
|
||||||
|
|
||||||
|
To run the benchmark, execute the following commands:
|
||||||
|
|
||||||
|
1. `llama-swap -config benchmark-config.yaml`
|
||||||
|
1. `./run-benchmark.sh http://localhost:8080 "3090-only" "3090-with-draft" "3090-P40-draft"`
|
||||||
|
|
||||||
|
The [benchmark script](run-benchmark.sh) generates a CSV output of the results, which can be converted to a Markdown table for readability.
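The CSV-to-Markdown conversion mentioned above is not part of this change. As a rough sketch of one way to do it, the Go program below reads CSV text and emits a Markdown table. The function name `csvToMarkdown` is hypothetical, and the sample row only mirrors the column layout printed by the script.

```go
package main

import (
	"encoding/csv"
	"fmt"
	"strings"
)

// csvToMarkdown is a hypothetical helper (not part of llama-swap) that turns
// CSV text with a header row into a Markdown table.
func csvToMarkdown(input string) (string, error) {
	records, err := csv.NewReader(strings.NewReader(input)).ReadAll()
	if err != nil {
		return "", err
	}
	if len(records) == 0 {
		return "", fmt.Errorf("no CSV rows found")
	}

	var b strings.Builder
	for i, row := range records {
		b.WriteString("| " + strings.Join(row, " | ") + " |\n")
		if i == 0 {
			// Separator line between the header and the data rows.
			b.WriteString("|" + strings.Repeat("---|", len(row)) + "\n")
		}
	}
	return b.String(), nil
}

func main() {
	// Sample input in the same shape as the script's output.
	in := "model,python,typescript,swift\n3090-only,34.03 tps,34.01 tps,34.01 tps\n"
	out, err := csvToMarkdown(in)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Print(out)
}
```

In practice the script's output could be redirected to a file and passed through a tool like this; any spreadsheet or CSV utility would work equally well.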
|
||||||
|
|
||||||
|
**Results (tokens/second)**
|
||||||
|
|
||||||
|
| model | python | typescript | swift |
|
||||||
|
|-----------------|--------|------------|-------|
|
||||||
|
| 3090-only | 34.03 | 34.01 | 34.01 |
|
||||||
|
| 3090-with-draft | 106.65 | 70.48 | 57.89 |
|
||||||
|
| 3090-P40-draft | 81.54 | 60.35 | 46.50 |
|
||||||
|
|
||||||
|
Many factors, such as the programming language, can significantly affect the performance gains. With a custom configuration file for benchmarking, it is easy to test different variations and discover what works best for your hardware.
|
||||||
|
|
||||||
|
Happy coding!
|
||||||
+40
@@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# This script generates CSV output showing the tokens/second for generating a Snake Game in Python, TypeScript and Swift
|
||||||
|
# It was created to test the effects of speculative decoding and the various draft settings on performance.
|
||||||
|
#
|
||||||
|
# Writing code with a low temperature seems to provide fairly consistent logic.
|
||||||
|
#
|
||||||
|
# Usage: ./run-benchmark.sh <url> <model1> [model2 ...]
|
||||||
|
# Example: ./run-benchmark.sh http://localhost:8080 model1 model2
|
||||||
|
|
||||||
|
if [ "$#" -lt 2 ]; then
|
||||||
|
echo "Usage: $0 <url> <model1> [model2 ...]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
url=$1; shift
|
||||||
|
|
||||||
|
echo "model,python,typescript,swift"
|
||||||
|
|
||||||
|
for model in "$@"; do
|
||||||
|
|
||||||
|
echo -n "$model,"
|
||||||
|
|
||||||
|
for lang in "python" "typescript" "swift"; do
|
||||||
|
# expects a llama.cpp after PR https://github.com/ggerganov/llama.cpp/pull/10548
|
||||||
|
# (Dec 3rd/2024)
|
||||||
|
time=$(curl -s --url "$url/v1/chat/completions" -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"top_k\": 1, \"timings_per_token\":true, \"model\":\"$model\"}" | jq -r .timings.predicted_per_second)
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
time="error"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$lang" != "swift" ]; then
|
||||||
|
printf "%0.2f tps," $time
|
||||||
|
else
|
||||||
|
printf "%0.2f tps\n" $time
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
done
|
||||||
@@ -20,6 +20,7 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaC
|
|||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
|||||||
+5
-1
@@ -5,6 +5,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/shlex"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,7 +82,10 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
|
|||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||||
|
|
||||||
// Split the command into arguments
|
// Split the command into arguments
|
||||||
args := strings.Fields(cmdStr)
|
args, err := shlex.Split(cmdStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure the command is not empty
|
// Ensure the command is not empty
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
|
|||||||
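The switch from `strings.Fields` to `shlex.Split` matters because `strings.Fields` splits on every run of whitespace, including whitespace inside quotes. The sketch below illustrates the difference using only the standard library. `splitQuoted` is a hypothetical, deliberately simplified splitter written for this example; it is not the `google/shlex` implementation and ignores backslash escapes and comments.

```go
package main

import (
	"fmt"
	"strings"
)

// splitQuoted is a hypothetical, simplified quote-aware splitter used only
// for illustration. It handles single and double quotes but no escapes.
func splitQuoted(s string) ([]string, error) {
	var args []string
	var cur strings.Builder
	inToken := false
	quote := rune(0) // 0 means "not inside quotes"

	for _, r := range s {
		switch {
		case quote != 0:
			if r == quote {
				quote = 0 // closing quote ends the quoted section
			} else {
				cur.WriteRune(r) // keep everything inside quotes verbatim
			}
		case r == '\'' || r == '"':
			quote = r
			inToken = true
		case r == ' ' || r == '\t' || r == '\n':
			if inToken {
				args = append(args, cur.String())
				cur.Reset()
				inToken = false
			}
		default:
			cur.WriteRune(r)
			inToken = true
		}
	}
	if quote != 0 {
		return nil, fmt.Errorf("unterminated %c quote", quote)
	}
	if inToken {
		args = append(args, cur.String())
	}
	return args, nil
}

func main() {
	cmd := `python model1.py -a "double quotes" --arg4 '"string in string"'`

	// Whitespace-only splitting breaks quoted arguments apart.
	fmt.Printf("%q\n", strings.Fields(cmd))

	// Quote-aware splitting keeps each quoted argument as one element.
	args, err := splitQuoted(cmd)
	fmt.Printf("%q %v\n", args, err)
}
```

Like `shlex.Split`, the sketch returns an error for unbalanced quotes, which is why `SanitizeCommand` now needs the added `err` check.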
+17
-8
@@ -148,17 +148,26 @@ func TestConfig_FindConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
// Test a simple command
|
|
||||||
args, err := SanitizeCommand("python model1.py")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"python", "model1.py"}, args)
|
|
||||||
|
|
||||||
// Test a command with spaces and newlines
|
// Test a command with spaces and newlines
|
||||||
args, err = SanitizeCommand(`python model1.py \
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
--arg1 value1 \
|
-a "double quotes" \
|
||||||
--arg2 value2`)
|
--arg2 'single quotes'
|
||||||
|
-s
|
||||||
|
--arg3 123 \
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"--arg2", "single quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", `"string in string"`,
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
// Test an empty command
|
// Test an empty command
|
||||||
args, err = SanitizeCommand("")
|
args, err = SanitizeCommand("")
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Logs</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
font-family: "Courier New", Courier, monospace;
|
||||||
|
}
|
||||||
|
#log-stream {
|
||||||
|
flex: 1;
|
||||||
|
margin: 1em;
|
||||||
|
padding: 10px;
|
||||||
|
background: #f4f4f4;
|
||||||
|
overflow-y: auto;
|
||||||
|
white-space: pre-wrap; /* Ensures line wrapping */
|
||||||
|
word-wrap: break-word; /* Ensures long words wrap */
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<pre id="log-stream">Waiting for logs...
|
||||||
|
</pre>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
// Establish an EventSource connection to the SSE endpoint
|
||||||
|
if (typeof(EventSource) !== "undefined") {
|
||||||
|
const eventSource = new EventSource("/logs/streamSSE");
|
||||||
|
|
||||||
|
eventSource.onmessage = function(event) {
|
||||||
|
// Append the new log message to the <pre> element
|
||||||
|
const logStream = document.getElementById('log-stream');
|
||||||
|
|
||||||
|
logStream.textContent += event.data;
|
||||||
|
|
||||||
|
// Auto-scroll to the bottom
|
||||||
|
logStream.scrollTop = logStream.scrollHeight;
|
||||||
|
};
|
||||||
|
|
||||||
|
eventSource.onerror = function(err) {
|
||||||
|
console.error("EventSource failed:", err);
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
console.error("SSE not supported by this browser.");
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -101,7 +101,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
// issue #19
|
// issue #19
|
||||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping long test")
|
t.Skip("skipping slow test")
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedMessage := "12345"
|
expectedMessage := "12345"
|
||||||
|
|||||||
+25
-12
@@ -14,6 +14,10 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PROFILE_SPLIT_CHAR = ":"
|
||||||
|
)
|
||||||
|
|
||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
@@ -94,6 +98,10 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
// Set the Content-Type header to application/json
|
// Set the Content-Type header to application/json
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
||||||
|
c.Header("Access-Control-Allow-Origin", origin)
|
||||||
|
}
|
||||||
|
|
||||||
// Encode the data as JSON and write it to the response writer
|
// Encode the data as JSON and write it to the response writer
|
||||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
||||||
@@ -106,15 +114,15 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// Check if requestedModel contains a /
|
// Check if requestedModel contains a profile prefix (PROFILE_SPLIT_CHAR)
|
||||||
groupName, modelName := "", requestedModel
|
profileName, modelName := "", requestedModel
|
||||||
if idx := strings.Index(requestedModel, "/"); idx != -1 {
|
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||||
groupName = requestedModel[:idx]
|
profileName = requestedModel[:idx]
|
||||||
modelName = requestedModel[idx+1:]
|
modelName = requestedModel[idx+1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
if groupName != "" {
|
if profileName != "" {
|
||||||
if _, found := pm.config.Profiles[groupName]; !found {
|
if _, found := pm.config.Profiles[profileName]; !found {
|
||||||
return nil, fmt.Errorf("model group not found %s", groupName)
|
return nil, fmt.Errorf("model group not found %s", profileName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +133,8 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// exit early when already running, otherwise stop everything and swap
|
// exit early when already running, otherwise stop everything and swap
|
||||||
requestedProcessKey := groupName + "/" + realModelName
|
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||||
|
|
||||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||||
return process, nil
|
return process, nil
|
||||||
}
|
}
|
||||||
@@ -133,25 +142,25 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
// stop all running models
|
// stop all running models
|
||||||
pm.stopProcesses()
|
pm.stopProcesses()
|
||||||
|
|
||||||
if groupName == "" {
|
if profileName == "" {
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
processKey := groupName + "/" + modelID
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
} else {
|
} else {
|
||||||
for _, modelName := range pm.config.Profiles[groupName] {
|
for _, modelName := range pm.config.Profiles[profileName] {
|
||||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName)
|
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
processKey := groupName + "/" + modelID
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -201,3 +210,7 @@ func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
|||||||
|
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ProcessKeyName(groupName, modelName string) string {
|
||||||
|
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||||
|
}
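To illustrate why the separator change from `/` to `:` matters, the sketch below splits a requested model name the same way `swapModel` does and rebuilds the process key. The lowercase names `profileSplitChar`, `splitRequestedModel`, and `processKeyName` are stand-ins written for this example, not the package's exported API.

```go
package main

import (
	"fmt"
	"strings"
)

// profileSplitChar mirrors PROFILE_SPLIT_CHAR from the diff.
const profileSplitChar = ":"

// splitRequestedModel is a hypothetical helper that separates an optional
// profile prefix from the model name at the first separator.
func splitRequestedModel(requested string) (profile, model string) {
	if idx := strings.Index(requested, profileSplitChar); idx != -1 {
		return requested[:idx], requested[idx+len(profileSplitChar):]
	}
	return "", requested
}

// processKeyName builds the map key the same way ProcessKeyName does.
func processKeyName(profile, model string) string {
	return profile + profileSplitChar + model
}

func main() {
	names := []string{"coding:qwen", "test:path1/model1", "path1/model1"}
	for _, name := range names {
		p, m := splitRequestedModel(name)
		fmt.Printf("%q -> profile=%q model=%q key=%q\n", name, p, m, processKeyName(p, m))
	}
}
```

With `:` as the separator, a model ID that itself contains a slash, such as `path1/model1` in the updated test, is no longer mistaken for a `profile/model` pair.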
|
||||||
|
|||||||
@@ -1,13 +1,34 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:embed html/logs.html
|
||||||
|
var logsHTML []byte
|
||||||
|
|
||||||
|
// make sure embed is kept there by the IDE auto-package importer
|
||||||
|
var _ = embed.FS{}
|
||||||
|
|
||||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||||
|
|
||||||
|
accept := c.GetHeader("Accept")
|
||||||
|
if strings.Contains(accept, "text/html") {
|
||||||
|
// Set the Content-Type header to text/html
|
||||||
|
c.Header("Content-Type", "text/html")
|
||||||
|
|
||||||
|
// Write the embedded HTML content to the response
|
||||||
|
_, err := c.Writer.Write(logsHTML)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
history := pm.logMonitor.GetHistory()
|
history := pm.logMonitor.GetHistory()
|
||||||
_, err := c.Writer.Write(history)
|
_, err := c.Writer.Write(history)
|
||||||
@@ -15,6 +36,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
|||||||
c.AbortWithError(http.StatusInternalServerError, err)
|
c.AbortWithError(http.StatusInternalServerError, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
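The handler above picks its response format from the `Accept` header: browsers get the embedded HTML page, other clients get plain-text history. The sketch below reproduces that branching with `net/http` only. `logsHandler`, `logsPage`, and `history` are placeholders for this example; the real code uses gin and the embedded `html/logs.html`.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// Placeholders standing in for the embedded HTML page and the log history.
var (
	logsPage = []byte("<!DOCTYPE html><title>Logs</title>")
	history  = []byte("log line 1\nlog line 2\n")
)

// logsHandler is a hypothetical stdlib version of the Accept-based branching.
func logsHandler(w http.ResponseWriter, r *http.Request) {
	if strings.Contains(r.Header.Get("Accept"), "text/html") {
		w.Header().Set("Content-Type", "text/html")
		w.Write(logsPage)
		return
	}
	w.Header().Set("Content-Type", "text/plain")
	w.Write(history)
}

// contentTypeFor sends a request with the given Accept header and reports
// the Content-Type the handler chose.
func contentTypeFor(accept string) string {
	req := httptest.NewRequest("GET", "/logs", nil)
	if accept != "" {
		req.Header.Set("Accept", accept)
	}
	rec := httptest.NewRecorder()
	logsHandler(rec, req)
	return rec.Header().Get("Content-Type")
}

func main() {
	// A browser-style Accept header selects the HTML page.
	fmt.Println(contentTypeFor("text/html,application/xhtml+xml"))
	// curl sends "*/*" by default, which falls through to plain text.
	fmt.Println(contentTypeFor("*/*"))
}
```

This lets the same `/logs` URL serve both the browser interface and the `curl` commands from the README.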
|
||||||
|
|
||||||
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -33,7 +34,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
|
||||||
_, exists := proxy.currentProcesses["/"+modelName]
|
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
||||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -43,21 +44,31 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
|
|
||||||
|
model1 := "path1/model1"
|
||||||
|
model2 := "path2/model2"
|
||||||
|
|
||||||
|
profileModel1 := ProcessKeyName("test", model1)
|
||||||
|
profileModel2 := ProcessKeyName("test", model2)
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
model1: getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
model2: getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
Profiles: map[string][]string{
|
Profiles: map[string][]string{
|
||||||
"test": {"model1", "model2"},
|
"test": {model1, model2},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
for modelID, requestedModel := range map[string]string{"model1": "test/model1", "model2": "test/model2"} {
|
for modelID, requestedModel := range map[string]string{
|
||||||
|
"model1": profileModel1,
|
||||||
|
"model2": profileModel2,
|
||||||
|
} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -69,11 +80,11 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
|
|
||||||
// make sure there's two loaded models
|
// make sure there's two loaded models
|
||||||
assert.Len(t, proxy.currentProcesses, 2)
|
assert.Len(t, proxy.currentProcesses, 2)
|
||||||
_, exists := proxy.currentProcesses["test/model1"]
|
_, exists := proxy.currentProcesses[profileModel1]
|
||||||
assert.True(t, exists, "expected test/model1 key in currentProcesses")
|
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
||||||
|
|
||||||
_, exists = proxy.currentProcesses["test/model2"]
|
_, exists = proxy.currentProcesses[profileModel2]
|
||||||
assert.True(t, exists, "expected test/model2 key in currentProcesses")
|
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
||||||
}
|
}
|
||||||
|
|
||||||
// When a request for a different model comes in ProxyManager should wait until
|
// When a request for a different model comes in ProxyManager should wait until
|
||||||
@@ -131,3 +142,71 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
assert.Equal(t, key, result)
|
assert.Equal(t, key, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
|
||||||
|
// Create a test request
|
||||||
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
|
req.Header.Add("Origin", "i-am-the-origin")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the listModelsHandler
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
|
||||||
|
// Check the response status code
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
// Check for Access-Control-Allow-Origin
|
||||||
|
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
|
||||||
|
|
||||||
|
// Parse the JSON response
|
||||||
|
var response struct {
|
||||||
|
Data []map[string]interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the number of models returned
|
||||||
|
assert.Len(t, response.Data, 3)
|
||||||
|
|
||||||
|
// Check the details of each model
|
||||||
|
expectedModels := map[string]struct{}{
|
||||||
|
"model1": {},
|
||||||
|
"model2": {},
|
||||||
|
"model3": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range response.Data {
|
||||||
|
modelID, ok := model["id"].(string)
|
||||||
|
assert.True(t, ok, "model ID should be a string")
|
||||||
|
_, exists := expectedModels[modelID]
|
||||||
|
assert.True(t, exists, "unexpected model ID: %s", modelID)
|
||||||
|
delete(expectedModels, modelID)
|
||||||
|
|
||||||
|
object, ok := model["object"].(string)
|
||||||
|
assert.True(t, ok, "object should be a string")
|
||||||
|
assert.Equal(t, "model", object)
|
||||||
|
|
||||||
|
created, ok := model["created"].(float64)
|
||||||
|
assert.True(t, ok, "created should be a number")
|
||||||
|
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
|
||||||
|
|
||||||
|
ownedBy, ok := model["owned_by"].(string)
|
||||||
|
assert.True(t, ok, "owned_by should be a string")
|
||||||
|
assert.Equal(t, "llama-swap", ownedBy)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure all expected models were returned
|
||||||
|
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||||
|
}
|
||||||
|
|||||||