Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2441b383d3 | |||
| 25f251699c | |||
| 7f37bcc6eb | |||
| 519c3a4d22 | |||
| 9dc4bcb46c | |||
| cb876c143b | |||
| bc652709a5 | |||
| 9548931258 | |||
| 5c5a5da664 | |||
| aa9ef59aa5 | |||
| 09e52c0500 | |||
| ca9063ffbe | |||
| 21d7973d11 | |||
| cc450e9c5f | |||
| 27465fe053 | |||
| 9667989727 |
@@ -0,0 +1,15 @@
|
|||||||
|
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||||
|
language: "en-US"
|
||||||
|
early_access: false
|
||||||
|
reviews:
|
||||||
|
profile: "chill"
|
||||||
|
request_changes_workflow: false
|
||||||
|
high_level_summary: true
|
||||||
|
poem: false
|
||||||
|
review_status: true
|
||||||
|
collapse_walkthrough: false
|
||||||
|
auto_review:
|
||||||
|
enabled: true
|
||||||
|
drafts: false
|
||||||
|
chat:
|
||||||
|
auto_reply: true
|
||||||
@@ -20,10 +20,10 @@ clean:
|
|||||||
rm -rf $(BUILD_DIR)
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test -short -v ./proxy
|
go test -short -v -count=1 ./proxy
|
||||||
|
|
||||||
test-all:
|
test-all:
|
||||||
go test -v ./proxy
|
go test -v -count=1 ./proxy
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
mac:
|
mac:
|
||||||
|
|||||||
@@ -70,6 +70,14 @@ healthCheckTimeout: 60
|
|||||||
# Valid log levels: debug, info (default), warn, error
|
# Valid log levels: debug, info (default), warn, error
|
||||||
logLevel: info
|
logLevel: info
|
||||||
|
|
||||||
|
# Automatic Port Values
|
||||||
|
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
|
||||||
|
# when you use ${PORT} you can omit a custom model.proxy value, as it will
|
||||||
|
# default to http://localhost:${PORT}
|
||||||
|
|
||||||
|
# override the default port (5800) for automatic port values
|
||||||
|
startPort: 10001
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
@@ -83,6 +91,7 @@ models:
|
|||||||
- "CUDA_VISIBLE_DEVICES=0"
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
# where to reach the server started by cmd, make sure the ports match
|
||||||
|
# can be omitted if you use an automatic ${PORT} in cmd
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# aliases names to use this model for
|
# aliases names to use this model for
|
||||||
@@ -109,14 +118,14 @@ models:
|
|||||||
# but they can still be requested as normal
|
# but they can still be requested as normal
|
||||||
"qwen-unlisted":
|
"qwen-unlisted":
|
||||||
unlisted: true
|
unlisted: true
|
||||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
# Docker Support (v26.1.4+ required!)
|
||||||
"docker-llama":
|
"docker-llama":
|
||||||
proxy: "http://127.0.0.1:9790"
|
proxy: "http://127.0.0.1:${PORT}"
|
||||||
cmd: >
|
cmd: >
|
||||||
docker run --name dockertest
|
docker run --name dockertest
|
||||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
@@ -180,18 +189,13 @@ groups:
|
|||||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
llama-s
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||||
|
|
||||||
Docker is the quickest way to try out llama-swap:
|
Docker is the quickest way to try out llama-swap:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
# use CPU inference
|
# use CPU inference
|
||||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||||
|
|
||||||
@@ -227,7 +231,7 @@ Specific versions are also available and are tagged with the llama-swap, archite
|
|||||||
|
|
||||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||||
|
|
||||||
```
|
```shell
|
||||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||||
-v /path/to/models:/models \
|
-v /path/to/models:/models \
|
||||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||||
@@ -242,7 +246,12 @@ Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are
|
|||||||
|
|
||||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
1. Run the binary with `llama-swap --config path/to/config.yaml`.
|
||||||
|
Available flags:
|
||||||
|
- `--config`: Path to the configuration file (default: `config.yaml`).
|
||||||
|
- `--listen`: Address and port to listen on (default: `:8080`).
|
||||||
|
- `--version`: Show version information and exit.
|
||||||
|
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
|
||||||
|
|
||||||
### Building from source
|
### Building from source
|
||||||
|
|
||||||
@@ -257,7 +266,7 @@ Open the `http://<host>/logs` with your browser to get a web interface with stre
|
|||||||
|
|
||||||
Of course, CLI access is also supported:
|
Of course, CLI access is also supported:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
# sends up to the last 10KB of logs
|
# sends up to the last 10KB of logs
|
||||||
curl http://host/logs'
|
curl http://host/logs'
|
||||||
|
|
||||||
|
|||||||
+16
-20
@@ -5,13 +5,20 @@ healthCheckTimeout: 90
|
|||||||
# valid log levels: debug, info (default), warn, error
|
# valid log levels: debug, info (default), warn, error
|
||||||
logLevel: debug
|
logLevel: debug
|
||||||
|
|
||||||
|
# creating a coding profile with models for code generation and general questions
|
||||||
|
groups:
|
||||||
|
coding:
|
||||||
|
swap: false
|
||||||
|
members:
|
||||||
|
- "qwen"
|
||||||
|
- "llama"
|
||||||
|
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
cmd: >
|
cmd: >
|
||||||
models/llama-server-osx
|
models/llama-server-osx
|
||||||
--port 9001
|
--port ${PORT}
|
||||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||||
proxy: http://127.0.0.1:9001
|
|
||||||
|
|
||||||
# list of model name aliases this llama.cpp instance can serve
|
# list of model name aliases this llama.cpp instance can serve
|
||||||
aliases:
|
aliases:
|
||||||
@@ -24,17 +31,15 @@ models:
|
|||||||
ttl: 5
|
ttl: 5
|
||||||
|
|
||||||
"qwen":
|
"qwen":
|
||||||
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||||
proxy: http://127.0.0.1:9002
|
|
||||||
aliases:
|
aliases:
|
||||||
- gpt-3.5-turbo
|
- gpt-3.5-turbo
|
||||||
|
|
||||||
# Embedding example with Nomic
|
# Embedding example with Nomic
|
||||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
||||||
"nomic":
|
"nomic":
|
||||||
proxy: http://127.0.0.1:9005
|
|
||||||
cmd: >
|
cmd: >
|
||||||
models/llama-server-osx --port 9005
|
models/llama-server-osx --port ${PORT}
|
||||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
||||||
--ctx-size 8192
|
--ctx-size 8192
|
||||||
--batch-size 8192
|
--batch-size 8192
|
||||||
@@ -46,19 +51,17 @@ models:
|
|||||||
# Reranking example with bge-reranker
|
# Reranking example with bge-reranker
|
||||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
||||||
"bge-reranker":
|
"bge-reranker":
|
||||||
proxy: http://127.0.0.1:9006
|
|
||||||
cmd: >
|
cmd: >
|
||||||
models/llama-server-osx --port 9006
|
models/llama-server-osx --port ${PORT}
|
||||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
||||||
--ctx-size 8192
|
--ctx-size 8192
|
||||||
--reranking
|
--reranking
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
# Docker Support (v26.1.4+ required!)
|
||||||
"dockertest":
|
"dockertest":
|
||||||
proxy: "http://127.0.0.1:9790"
|
|
||||||
cmd: >
|
cmd: >
|
||||||
docker run --name dockertest
|
docker run --name dockertest
|
||||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
@@ -67,8 +70,7 @@ models:
|
|||||||
env:
|
env:
|
||||||
- CUDA_VISIBLE_DEVICES=0,1
|
- CUDA_VISIBLE_DEVICES=0,1
|
||||||
- env1=hello
|
- env1=hello
|
||||||
cmd: build/simple-responder --port 8999
|
cmd: build/simple-responder --port ${PORT}
|
||||||
proxy: http://127.0.0.1:8999
|
|
||||||
unlisted: true
|
unlisted: true
|
||||||
|
|
||||||
# use "none" to skip check. Caution this may cause some requests to fail
|
# use "none" to skip check. Caution this may cause some requests to fail
|
||||||
@@ -83,10 +85,4 @@ models:
|
|||||||
"broken_timeout":
|
"broken_timeout":
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||||
proxy: http://127.0.0.1:9000
|
proxy: http://127.0.0.1:9000
|
||||||
unlisted: true
|
unlisted: true
|
||||||
|
|
||||||
# creating a coding profile with models for code generation and general questions
|
|
||||||
profiles:
|
|
||||||
coding:
|
|
||||||
- "qwen"
|
|
||||||
- "llama"
|
|
||||||
@@ -3,6 +3,7 @@ module github.com/mostlygeek/llama-swap
|
|||||||
go 1.23.0
|
go 1.23.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
|
|||||||
@@ -9,12 +9,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
@@ -23,6 +27,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
|||||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||||
@@ -74,34 +80,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
|||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
|
||||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
|
||||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
|
||||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
|
||||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
|
||||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
|
||||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
|
||||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
|
||||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
|
||||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
|
||||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
|
||||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
|||||||
+137
-11
@@ -1,25 +1,34 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
var version string = "0"
|
var (
|
||||||
var commit string = "abcd1234"
|
version string = "0"
|
||||||
var date = "unknown"
|
commit string = "abcd1234"
|
||||||
|
date string = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
configPath := flag.String("config", "config.yaml", "config file name")
|
configPath := flag.String("config", "config.yaml", "config file name")
|
||||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||||
showVersion := flag.Bool("version", false, "show version of build")
|
showVersion := flag.Bool("version", false, "show version of build")
|
||||||
|
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||||
|
|
||||||
flag.Parse() // Parse the command-line flags
|
flag.Parse() // Parse the command-line flags
|
||||||
|
|
||||||
@@ -46,18 +55,135 @@ func main() {
|
|||||||
|
|
||||||
proxyManager := proxy.New(config)
|
proxyManager := proxy.New(config)
|
||||||
|
|
||||||
|
// Setup channels for server management
|
||||||
|
reloadChan := make(chan *proxy.ProxyManager)
|
||||||
|
exitChan := make(chan struct{})
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Create server with initial handler
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: *listenStr,
|
||||||
|
Handler: proxyManager,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||||
go func() {
|
go func() {
|
||||||
<-sigChan
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
fmt.Println("Shutting down llama-swap")
|
fmt.Printf("Fatal server error: %v\n", err)
|
||||||
proxyManager.Shutdown()
|
close(exitChan)
|
||||||
os.Exit(0)
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
fmt.Println("llama-swap listening on " + *listenStr)
|
// Handle config reloads and signals
|
||||||
if err := proxyManager.Run(*listenStr); err != nil {
|
go func() {
|
||||||
fmt.Printf("Server error: %v\n", err)
|
currentManager := proxyManager
|
||||||
os.Exit(1)
|
for {
|
||||||
|
select {
|
||||||
|
case newManager := <-reloadChan:
|
||||||
|
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
||||||
|
// Stop old manager processes gracefully (this waits for in-flight requests)
|
||||||
|
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
||||||
|
// Now do a full shutdown to clear the process map
|
||||||
|
currentManager.Shutdown()
|
||||||
|
currentManager = newManager
|
||||||
|
srv.Handler = newManager
|
||||||
|
log.Println("Server handler updated with new config")
|
||||||
|
case sig := <-sigChan:
|
||||||
|
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
currentManager.Shutdown()
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
fmt.Printf("Server shutdown error: %v\n", err)
|
||||||
|
}
|
||||||
|
close(exitChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start file watcher if requested
|
||||||
|
if *watchConfig {
|
||||||
|
absConfigPath, err := filepath.Abs(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
|
||||||
|
} else {
|
||||||
|
go watchConfigFileWithReload(absConfigPath, reloadChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for exit signal
|
||||||
|
<-exitChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
|
||||||
|
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer watcher.Close()
|
||||||
|
|
||||||
|
err = watcher.Add(configPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Watching config file for changes: %s", configPath)
|
||||||
|
|
||||||
|
var debounceTimer *time.Timer
|
||||||
|
debounceDuration := 2 * time.Second
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event, ok := <-watcher.Events:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// We only care about writes to the specific config file
|
||||||
|
if event.Name == configPath && event.Has(fsnotify.Write) {
|
||||||
|
// Reset or start the debounce timer
|
||||||
|
if debounceTimer != nil {
|
||||||
|
debounceTimer.Stop()
|
||||||
|
}
|
||||||
|
debounceTimer = time.AfterFunc(debounceDuration, func() {
|
||||||
|
log.Printf("Config file modified: %s, reloading...", event.Name)
|
||||||
|
|
||||||
|
// Try up to 3 times with exponential backoff
|
||||||
|
var newConfig proxy.Config
|
||||||
|
var err error
|
||||||
|
for retries := 0; retries < 3; retries++ {
|
||||||
|
// Load new configuration
|
||||||
|
newConfig, err = proxy.LoadConfig(configPath)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
|
||||||
|
if retries < 2 {
|
||||||
|
time.Sleep(time.Duration(1<<retries) * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to load new config after retries: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new ProxyManager with new config
|
||||||
|
newPM := proxy.New(newConfig)
|
||||||
|
reloadChan <- newPM
|
||||||
|
log.Println("Config reloaded successfully")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case err, ok := <-watcher.Errors:
|
||||||
|
if !ok {
|
||||||
|
log.Println("File watcher error channel closed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("File watcher error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ func main() {
|
|||||||
|
|
||||||
silent := flag.Bool("silent", false, "disable all logging")
|
silent := flag.Bool("silent", false, "disable all logging")
|
||||||
|
|
||||||
|
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
|
||||||
|
|
||||||
flag.Parse() // Parse the command-line flags
|
flag.Parse() // Parse the command-line flags
|
||||||
|
|
||||||
// Create a new Gin router
|
// Create a new Gin router
|
||||||
@@ -33,14 +35,17 @@ func main() {
|
|||||||
|
|
||||||
// Set up the handler function using the provided response message
|
// Set up the handler function using the provided response message
|
||||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
// add a wait to simulate a slow query
|
// add a wait to simulate a slow query
|
||||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
time.Sleep(wait)
|
time.Sleep(wait)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.String(200, *responseMessage)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// for issue #62 to check model name strips profile slug
|
// for issue #62 to check model name strips profile slug
|
||||||
@@ -63,8 +68,11 @@ func main() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "application/json")
|
||||||
c.String(200, *responseMessage)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// issue #41
|
// issue #41
|
||||||
@@ -104,6 +112,10 @@ func main() {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
||||||
|
// expose some header values for testing
|
||||||
|
"h_content_type": c.GetHeader("Content-Type"),
|
||||||
|
"h_content_length": c.GetHeader("Content-Length"),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -180,6 +192,10 @@ func main() {
|
|||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !*silent {
|
||||||
|
fmt.Printf("My PID: %d\n", os.Getpid())
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("simple-responder listening on %s\n", address)
|
log.Printf("simple-responder listening on %s\n", address)
|
||||||
// service connections
|
// service connections
|
||||||
@@ -190,11 +206,36 @@ func main() {
|
|||||||
|
|
||||||
// Wait for interrupt signal to gracefully shutdown the server with
|
// Wait for interrupt signal to gracefully shutdown the server with
|
||||||
// a timeout of 5 seconds.
|
// a timeout of 5 seconds.
|
||||||
quit := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
// kill (no param) default send syscall.SIGTERM
|
// kill (no param) default send syscall.SIGTERM
|
||||||
// kill -2 is syscall.SIGINT
|
// kill -2 is syscall.SIGINT
|
||||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
<-quit
|
|
||||||
|
countSigInt := 0
|
||||||
|
|
||||||
|
runloop:
|
||||||
|
for {
|
||||||
|
signal := <-sigChan
|
||||||
|
switch signal {
|
||||||
|
case syscall.SIGINT:
|
||||||
|
countSigInt++
|
||||||
|
if countSigInt > 1 {
|
||||||
|
break runloop
|
||||||
|
} else {
|
||||||
|
log.Println("Recieved SIGINT, send another SIGINT to shutdown")
|
||||||
|
}
|
||||||
|
case syscall.SIGTERM:
|
||||||
|
if *ignoreSigTerm {
|
||||||
|
log.Println("Ignoring SIGTERM")
|
||||||
|
} else {
|
||||||
|
log.Println("Recieved SIGTERM, shutting down")
|
||||||
|
break runloop
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break runloop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Println("simple-responder shutting down")
|
log.Println("simple-responder shutting down")
|
||||||
}
|
}
|
||||||
|
|||||||
+54
-1
@@ -2,8 +2,10 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/shlex"
|
"github.com/google/shlex"
|
||||||
@@ -21,6 +23,9 @@ type ModelConfig struct {
|
|||||||
UnloadAfter int `yaml:"ttl"`
|
UnloadAfter int `yaml:"ttl"`
|
||||||
Unlisted bool `yaml:"unlisted"`
|
Unlisted bool `yaml:"unlisted"`
|
||||||
UseModelName string `yaml:"useModelName"`
|
UseModelName string `yaml:"useModelName"`
|
||||||
|
|
||||||
|
// Limit concurrency of HTTP requests to process
|
||||||
|
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
@@ -62,6 +67,9 @@ type Config struct {
|
|||||||
|
|
||||||
// map aliases to actual model IDs
|
// map aliases to actual model IDs
|
||||||
aliases map[string]string
|
aliases map[string]string
|
||||||
|
|
||||||
|
// automatic port assignments
|
||||||
|
StartPort int `yaml:"startPort"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
@@ -83,7 +91,16 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig(path string) (Config, error) {
|
func LoadConfig(path string) (Config, error) {
|
||||||
data, err := os.ReadFile(path)
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
return LoadConfigFromReader(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Config{}, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
@@ -98,14 +115,50 @@ func LoadConfig(path string) (Config, error) {
|
|||||||
config.HealthCheckTimeout = 15
|
config.HealthCheckTimeout = 15
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set default port ranges
|
||||||
|
if config.StartPort == 0 {
|
||||||
|
// default to 5800
|
||||||
|
config.StartPort = 5800
|
||||||
|
} else if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
// Populate the aliases map
|
// Populate the aliases map
|
||||||
config.aliases = make(map[string]string)
|
config.aliases = make(map[string]string)
|
||||||
for modelName, modelConfig := range config.Models {
|
for modelName, modelConfig := range config.Models {
|
||||||
for _, alias := range modelConfig.Aliases {
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
config.aliases[alias] = modelName
|
config.aliases[alias] = modelName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// iterate over the models and replace any ${PORT} with the next available port
|
||||||
|
// Get and sort all model IDs first, makes testing more consistent
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||||
|
|
||||||
|
// iterate over the sorted models
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
if strings.Contains(modelConfig.Cmd, "${PORT}") {
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
|
||||||
|
if modelConfig.Proxy == "" {
|
||||||
|
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
|
||||||
|
} else {
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
|
||||||
|
}
|
||||||
|
nextPort++
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
} else if modelConfig.Proxy == "" {
|
||||||
|
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
||||||
|
}
|
||||||
|
}
|
||||||
config = AddDefaultGroupToConfig(config)
|
config = AddDefaultGroupToConfig(config)
|
||||||
// check that members are all unique in the groups
|
// check that members are all unique in the groups
|
||||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||||
|
|||||||
+105
-15
@@ -3,6 +3,7 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -43,6 +44,7 @@ models:
|
|||||||
checkEndpoint: "/"
|
checkEndpoint: "/"
|
||||||
model4:
|
model4:
|
||||||
cmd: path/to/cmd --arg1 one
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8082"
|
||||||
checkEndpoint: "/"
|
checkEndpoint: "/"
|
||||||
|
|
||||||
healthCheckTimeout: 15
|
healthCheckTimeout: 15
|
||||||
@@ -73,6 +75,7 @@ groups:
|
|||||||
}
|
}
|
||||||
|
|
||||||
expected := Config{
|
expected := Config{
|
||||||
|
StartPort: 5800,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
@@ -97,6 +100,7 @@ groups:
|
|||||||
},
|
},
|
||||||
"model4": {
|
"model4": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8082",
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -138,14 +142,6 @@ groups:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||||
// Create a temporary YAML file for testing
|
|
||||||
tempDir, err := os.MkdirTemp("", "test-config")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(tempDir)
|
|
||||||
|
|
||||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
|
||||||
content := `
|
content := `
|
||||||
models:
|
models:
|
||||||
model1:
|
model1:
|
||||||
@@ -171,15 +167,35 @@ groups:
|
|||||||
exclusive: false
|
exclusive: false
|
||||||
members: ["model2"]
|
members: ["model2"]
|
||||||
`
|
`
|
||||||
|
|
||||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write temporary file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the config and verify
|
// Load the config and verify
|
||||||
_, err = LoadConfig(tempFile)
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
assert.NotNil(t, err)
|
|
||||||
|
|
||||||
|
// a Contains as order of the map is not guaranteed
|
||||||
|
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
aliases:
|
||||||
|
- m1
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
aliases:
|
||||||
|
- m1
|
||||||
|
- m2
|
||||||
|
`
|
||||||
|
// Load the config and verify
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
|
||||||
|
// this is a contains because it could be `model1` or `model2` depending on the order
|
||||||
|
// go decided on the order of the map
|
||||||
|
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
@@ -269,3 +285,77 @@ func TestConfig_SanitizeCommand(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, args)
|
assert.Nil(t, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("Default Port Ranges", func(t *testing.T) {
|
||||||
|
content := ``
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
})
|
||||||
|
t.Run("User specific port ranges", func(t *testing.T) {
|
||||||
|
content := `startPort: 1000`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 1000, config.StartPort)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Invalid start port", func(t *testing.T) {
|
||||||
|
content := `startPort: abcd`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("start port must be greater than 1", func(t *testing.T) {
|
||||||
|
content := `startPort: -99`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Automatic port assignments", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 5800
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: svr --port ${PORT}
|
||||||
|
model2:
|
||||||
|
cmd: svr --port ${PORT}
|
||||||
|
proxy: "http://172.11.22.33:${PORT}"
|
||||||
|
model3:
|
||||||
|
cmd: svr --port 1999
|
||||||
|
proxy: "http://1.2.3.4:1999"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
||||||
|
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
||||||
|
|
||||||
|
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
||||||
|
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
||||||
|
|
||||||
|
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
||||||
|
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: svr --port 111
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -48,14 +48,18 @@ func getSimpleResponderPath() string {
|
|||||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
func getTestPort() int {
|
||||||
portMutex.Lock()
|
portMutex.Lock()
|
||||||
defer portMutex.Unlock()
|
defer portMutex.Unlock()
|
||||||
|
|
||||||
port := nextTestPort
|
port := nextTestPort
|
||||||
nextTestPort++
|
nextTestPort++
|
||||||
|
|
||||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
return port
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||||
|
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||||
|
|||||||
+82
-13
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -29,6 +30,13 @@ const (
|
|||||||
StateShutdown ProcessState = ProcessState("shutdown")
|
StateShutdown ProcessState = ProcessState("shutdown")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type StopStrategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StopImmediately StopStrategy = iota
|
||||||
|
StopWaitForInflightRequest
|
||||||
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config ModelConfig
|
||||||
@@ -56,10 +64,25 @@ type Process struct {
|
|||||||
// for managing shutdown state
|
// for managing shutdown state
|
||||||
shutdownCtx context.Context
|
shutdownCtx context.Context
|
||||||
shutdownCancel context.CancelFunc
|
shutdownCancel context.CancelFunc
|
||||||
|
|
||||||
|
// for managing concurrency limits
|
||||||
|
concurrencyLimitSemaphore chan struct{}
|
||||||
|
|
||||||
|
// stop timeout waiting for graceful shutdown
|
||||||
|
gracefulStopTimeout time.Duration
|
||||||
|
|
||||||
|
// track that this happened
|
||||||
|
upstreamWasStoppedWithKill bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
concurrentLimit := 10
|
||||||
|
if config.ConcurrencyLimit > 0 {
|
||||||
|
concurrentLimit = config.ConcurrencyLimit
|
||||||
|
} else {
|
||||||
|
proxyLogger.Debugf("Concurrency limit for model %s not set, defaulting to 10", ID)
|
||||||
|
}
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
@@ -72,6 +95,13 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
|
|||||||
state: StateStopped,
|
state: StateStopped,
|
||||||
shutdownCtx: ctx,
|
shutdownCtx: ctx,
|
||||||
shutdownCancel: cancel,
|
shutdownCancel: cancel,
|
||||||
|
|
||||||
|
// concurrency limit
|
||||||
|
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||||
|
|
||||||
|
// stop timeout
|
||||||
|
gracefulStopTimeout: 5 * time.Second,
|
||||||
|
upstreamWasStoppedWithKill: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -184,10 +214,19 @@ func (p *Process) start() error {
|
|||||||
return fmt.Errorf("start() failed: %v", err)
|
return fmt.Errorf("start() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture the exit error for later signaling
|
// Capture the exit error for later signalling
|
||||||
go func() {
|
go func() {
|
||||||
exitErr := p.cmd.Wait()
|
exitErr := p.cmd.Wait()
|
||||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||||
|
|
||||||
|
// there is a race condition when SIGKILL is used, p.cmd.Wait() returns, and then
|
||||||
|
// the code below fires, putting an error into cmdWaitChan. This code is to prevent this
|
||||||
|
if p.upstreamWasStoppedWithKill {
|
||||||
|
p.proxyLogger.Debugf("<%s> process was killed, NOT sending exitErr: %v", p.ID, exitErr)
|
||||||
|
p.upstreamWasStoppedWithKill = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.cmdWaitChan <- exitErr
|
p.cmdWaitChan <- exitErr
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -259,9 +298,9 @@ func (p *Process) start() error {
|
|||||||
if strings.Contains(err.Error(), "connection refused") {
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
endTime, _ := checkDeadline.Deadline()
|
endTime, _ := checkDeadline.Deadline()
|
||||||
ttl := time.Until(endTime)
|
ttl := time.Until(endTime)
|
||||||
p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds())
|
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err)
|
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,13 +339,25 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop will wait for inflight requests to complete before stopping the process.
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for any inflight requests before proceeding
|
// wait for any inflight requests before proceeding
|
||||||
|
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
p.StopImmediately()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||||
|
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||||
|
func (p *Process) StopImmediately() {
|
||||||
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
// calling Stop() when state is invalid is a no-op
|
||||||
@@ -316,7 +367,7 @@ func (p *Process) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// stop the process with a graceful exit timeout
|
// stop the process with a graceful exit timeout
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(p.gracefulStopTimeout)
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
|
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
|
||||||
@@ -325,10 +376,11 @@ func (p *Process) Stop() {
|
|||||||
|
|
||||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||||
// of time for any inflight requests to complete before shutting down. If the Process
|
// of time for any inflight requests to complete before shutting down. If the Process
|
||||||
// is in the state of starting, it will cancel it and shut it down
|
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||||
|
// the StateShutdown state, it can not be started again.
|
||||||
func (p *Process) Shutdown() {
|
func (p *Process) Shutdown() {
|
||||||
p.shutdownCancel()
|
p.shutdownCancel()
|
||||||
p.stopCommand(5 * time.Second)
|
p.stopCommand(p.gracefulStopTimeout)
|
||||||
p.state = StateShutdown
|
p.state = StateShutdown
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -344,31 +396,34 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
|||||||
defer cancelTimeout()
|
defer cancelTimeout()
|
||||||
|
|
||||||
if p.cmd == nil || p.cmd.Process == nil {
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID)
|
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.terminateProcess(); err != nil {
|
if err := p.terminateProcess(); err != nil {
|
||||||
p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err)
|
p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID)
|
p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID)
|
||||||
p.cmd.Process.Kill()
|
p.upstreamWasStoppedWithKill = true
|
||||||
|
if err := p.cmd.Process.Kill(); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
|
||||||
|
}
|
||||||
case err := <-p.cmdWaitChan:
|
case err := <-p.cmdWaitChan:
|
||||||
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
|
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
|
||||||
// because if we make it here then the cmd has been successfully running and made it
|
// because if we make it here then the cmd has been successfully running and made it
|
||||||
// through the health check. There is a possibility that ithe cmd crashed after the health check
|
// through the health check. There is a possibility that the cmd crashed after the health check
|
||||||
// succeeded but that's not a case llama-swap is handling for now.
|
// succeeded but that's not a case llama-swap is handling for now.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errno, ok := err.(syscall.Errno); ok {
|
if errno, ok := err.(syscall.Errno); ok {
|
||||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||||
p.proxyLogger.Infof("<%s> Process stopped OK", p.ID)
|
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||||
p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID)
|
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||||
} else {
|
} else {
|
||||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||||
}
|
}
|
||||||
@@ -414,6 +469,14 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p.concurrencyLimitSemaphore <- struct{}{}:
|
||||||
|
defer func() { <-p.concurrencyLimitSemaphore }()
|
||||||
|
default:
|
||||||
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.inFlightRequests.Add(1)
|
p.inFlightRequests.Add(1)
|
||||||
defer func() {
|
defer func() {
|
||||||
p.lastRequestHandled = time.Now()
|
p.lastRequestHandled = time.Now()
|
||||||
@@ -439,6 +502,12 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header = r.Header.Clone()
|
req.Header = r.Header.Clone()
|
||||||
|
|
||||||
|
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
req.ContentLength = contentLength
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
|||||||
@@ -340,3 +340,106 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
|||||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||||
assert.Equal(t, process.CurrentState(), StateFailed)
|
assert.Equal(t, process.CurrentState(), StateFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping long concurrency limit test")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := "concurrency_limit_test"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
// only allow 1 concurrent request at a time
|
||||||
|
config.ConcurrencyLimit = 1
|
||||||
|
|
||||||
|
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||||
|
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
// launch a goroutine first to take up the semaphore
|
||||||
|
go func() {
|
||||||
|
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req1)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// let the goroutine start
|
||||||
|
<-time.After(time.Millisecond * 25)
|
||||||
|
|
||||||
|
denied := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, denied)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_StopImmediately(t *testing.T) {
|
||||||
|
expectedMessage := "test_stop_immediate"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
go func() {
|
||||||
|
// slow, but will get killed by StopImmediate
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
}()
|
||||||
|
<-time.After(time.Millisecond)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||||
|
// the upstream command
|
||||||
|
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||||
|
|
||||||
|
expectedMessage := "test_sigkill"
|
||||||
|
binaryPath := getSimpleResponderPath()
|
||||||
|
port := getTestPort()
|
||||||
|
|
||||||
|
config := ModelConfig{
|
||||||
|
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||||
|
// to force the process to exit
|
||||||
|
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
// reduce to make testing go faster
|
||||||
|
process.gracefulStopTimeout = time.Second
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
|
||||||
|
waitChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
// slow, but will get killed by StopImmediate
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
|
||||||
|
// StatusOK because that was already sent before the kill
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||||
|
// then the unexpected EOF is sent after the kill
|
||||||
|
assert.Equal(t, "1unexpected EOF\n", w.Body.String())
|
||||||
|
close(waitChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-time.After(time.Millisecond)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
|
||||||
|
// the request should have been interrupted by SIGKILL
|
||||||
|
<-waitChan
|
||||||
|
}
|
||||||
|
|||||||
@@ -76,14 +76,10 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
|||||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pg *ProcessGroup) StopProcesses() {
|
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
defer pg.Unlock()
|
defer pg.Unlock()
|
||||||
pg.stopProcesses()
|
|
||||||
}
|
|
||||||
|
|
||||||
// stopProcesses stops all processes in the group
|
|
||||||
func (pg *ProcessGroup) stopProcesses() {
|
|
||||||
if len(pg.processes) == 0 {
|
if len(pg.processes) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -94,7 +90,12 @@ func (pg *ProcessGroup) stopProcesses() {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(process *Process) {
|
go func(process *Process) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
process.Stop()
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
}(process)
|
}(process)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
||||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses()
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2"}
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses()
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model3", "model4"}
|
tests := []string{"model3", "model4"}
|
||||||
|
|
||||||
|
|||||||
+31
-18
@@ -82,6 +82,11 @@ func New(config Config) *ProxyManager {
|
|||||||
pm.processGroups[groupID] = processGroup
|
pm.processGroups[groupID] = processGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm.setupGinEngine()
|
||||||
|
return pm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) setupGinEngine() {
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
// Start timer
|
// Start timer
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
@@ -192,19 +197,18 @@ func New(config Config) *ProxyManager {
|
|||||||
|
|
||||||
// Disable console color for testing
|
// Disable console color for testing
|
||||||
gin.DisableConsoleColor()
|
gin.DisableConsoleColor()
|
||||||
|
|
||||||
return pm
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) Run(addr ...string) error {
|
// ServeHTTP implements http.Handler interface
|
||||||
return pm.ginEngine.Run(addr...)
|
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
|
||||||
pm.ginEngine.ServeHTTP(w, r)
|
pm.ginEngine.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) StopProcesses() {
|
// StopProcesses acquires a lock and stops all running upstream processes.
|
||||||
|
// This is the public method safe for concurrent calls.
|
||||||
|
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
||||||
|
// a complete shutdown, allowing for process replacement without full termination.
|
||||||
|
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
@@ -214,15 +218,14 @@ func (pm *ProxyManager) StopProcesses() {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(processGroup *ProcessGroup) {
|
go func(processGroup *ProcessGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
processGroup.stopProcesses()
|
processGroup.StopProcesses(strategy)
|
||||||
}(processGroup)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown is called to shutdown all upstream processes
|
// Shutdown stops all processes managed by this ProxyManager
|
||||||
// when llama-swap is shutting down.
|
|
||||||
func (pm *ProxyManager) Shutdown() {
|
func (pm *ProxyManager) Shutdown() {
|
||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
@@ -257,7 +260,7 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
|||||||
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
||||||
for groupId, otherGroup := range pm.processGroups {
|
for groupId, otherGroup := range pm.processGroups {
|
||||||
if groupId != processGroup.id && !otherGroup.persistent {
|
if groupId != processGroup.id && !otherGroup.persistent {
|
||||||
otherGroup.StopProcesses()
|
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -305,6 +308,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
@@ -347,11 +351,13 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
if requestedModel == "" {
|
if requestedModel == "" {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// issue #69 allow custom model names to be sent to upstream
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
@@ -373,15 +379,11 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
||||||
// Create a new buffer for the reconstructed request
|
|
||||||
var requestBuffer bytes.Buffer
|
|
||||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
||||||
|
|
||||||
// Parse multipart form
|
// Parse multipart form
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
@@ -398,8 +400,14 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
|
// Create a new buffer for the reconstructed request
|
||||||
|
var requestBuffer bytes.Buffer
|
||||||
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||||
|
|
||||||
// Copy all form values
|
// Copy all form values
|
||||||
for key, values := range c.Request.MultipartForm.Value {
|
for key, values := range c.Request.MultipartForm.Value {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
@@ -473,10 +481,15 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
modifiedReq.Header = c.Request.Header.Clone()
|
modifiedReq.Header = c.Request.Header.Clone()
|
||||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||||
|
|
||||||
|
// set the content length of the body
|
||||||
|
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||||
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||||
|
|
||||||
// Use the modified request for proxying
|
// Use the modified request for proxying
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -491,7 +504,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
pm.StopProcesses()
|
pm.StopProcesses(StopImmediately)
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+55
-30
@@ -8,6 +8,7 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -26,14 +27,14 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
for _, modelName := range []string{"model1", "model2"} {
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
}
|
}
|
||||||
@@ -62,7 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2"}
|
||||||
for _, requestedModel := range tests {
|
for _, requestedModel := range tests {
|
||||||
@@ -71,10 +72,9 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), requestedModel)
|
assert.Contains(t, w.Body.String(), requestedModel)
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +105,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
// make requests to load all models, loading model1 should not affect model2
|
// make requests to load all models, loading model1 should not affect model2
|
||||||
tests := []string{"model2", "model1"}
|
tests := []string{"model2", "model1"}
|
||||||
@@ -114,7 +114,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
|||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), requestedModel)
|
assert.Contains(t, w.Body.String(), requestedModel)
|
||||||
}
|
}
|
||||||
@@ -141,7 +141,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
results := map[string]string{}
|
results := map[string]string{}
|
||||||
|
|
||||||
@@ -157,15 +157,16 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
|
var response map[string]string
|
||||||
results[key] = w.Body.String()
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
results[key] = response["responseMessage"]
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
}(key)
|
}(key)
|
||||||
|
|
||||||
@@ -199,7 +200,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// Call the listModelsHandler
|
// Call the listModelsHandler
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
// Check the response status code
|
// Check the response status code
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -289,7 +290,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||||
}(modelName)
|
}(modelName)
|
||||||
@@ -315,12 +316,12 @@ func TestProxyManager_Unload(t *testing.T) {
|
|||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||||
req = httptest.NewRequest("GET", "/unload", nil)
|
req = httptest.NewRequest("GET", "/unload", nil)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Equal(t, w.Body.String(), "OK")
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
|
|
||||||
@@ -331,7 +332,6 @@ func TestProxyManager_Unload(t *testing.T) {
|
|||||||
|
|
||||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||||
|
|
||||||
// Shared configuration
|
// Shared configuration
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := AddDefaultGroupToConfig(Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
@@ -339,7 +339,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
LogLevel: "debug",
|
LogLevel: "warn",
|
||||||
})
|
})
|
||||||
|
|
||||||
// Define a helper struct to parse the JSON response.
|
// Define a helper struct to parse the JSON response.
|
||||||
@@ -352,12 +352,12 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
// Create proxy once for all tests
|
// Create proxy once for all tests
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
t.Run("no models loaded", func(t *testing.T) {
|
t.Run("no models loaded", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/running", nil)
|
req := httptest.NewRequest("GET", "/running", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
@@ -375,13 +375,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
reqBody := `{"model":"model1"}`
|
reqBody := `{"model":"model1"}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
// Simulate browser call for the `/running` endpoint.
|
// Simulate browser call for the `/running` endpoint.
|
||||||
req = httptest.NewRequest("GET", "/running", nil)
|
req = httptest.NewRequest("GET", "/running", nil)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
var response RunningResponse
|
var response RunningResponse
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
@@ -407,7 +407,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
// Create a buffer with multipart form data
|
// Create a buffer with multipart form data
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
@@ -433,7 +433,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
// Verify the response
|
// Verify the response
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
@@ -442,6 +442,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "TheExpectedModel", response["model"])
|
assert.Equal(t, "TheExpectedModel", response["model"])
|
||||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||||
|
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||||
@@ -460,7 +461,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
requestedModel := "model1"
|
requestedModel := "model1"
|
||||||
|
|
||||||
@@ -469,7 +470,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
proxy.HandlerFunc(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||||
})
|
})
|
||||||
@@ -496,7 +497,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
// Verify the response
|
// Verify the response
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
@@ -556,7 +557,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
||||||
for k, v := range tt.requestHeaders {
|
for k, v := range tt.requestHeaders {
|
||||||
@@ -564,7 +565,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.ginEngine.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||||
|
|
||||||
@@ -585,10 +586,34 @@ func TestProxyManager_Upstream(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses()
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
proxy.HandlerFunc(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||||
|
config := AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
var response map[string]string
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
assert.Equal(t, "81", response["h_content_length"])
|
||||||
|
assert.Equal(t, "model1", response["responseMessage"])
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user