Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 73ad85ea69 | |||
| 533162ce6a | |||
| ba39ed4c18 | |||
| 21f54f96c2 | |||
| 7eec51f3f2 | |||
| 5021e0f299 | |||
| c9233d2c9a | |||
| a33ac6f8fb | |||
| 401aa88949 | |||
| e9e88fd229 | |||
| c3b4bb1684 | |||
| e5c909ddf7 | |||
| 36a31f450f | |||
| a8e5ee13b9 | |||
| 5944a86e86 | |||
| 63d4a7d0eb | |||
| f45469f7ff | |||
| 34f9fd7340 | |||
| 8448efa7fc | |||
| 8cf2a389d8 | |||
| 0f133f5b74 | |||
| 1510b3fbd9 | |||
| 0f8a8e70f1 | |||
| 6c3819022c |
@@ -9,6 +9,12 @@ all: mac linux simple-responder
|
|||||||
clean:
|
clean:
|
||||||
rm -rf $(BUILD_DIR)
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test -short -v ./proxy
|
||||||
|
|
||||||
|
test-all:
|
||||||
|
go test -v ./proxy
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
mac:
|
mac:
|
||||||
@echo "Building Mac binary..."
|
@echo "Building Mac binary..."
|
||||||
@@ -19,10 +25,11 @@ linux:
|
|||||||
@echo "Building Linux binary..."
|
@echo "Building Linux binary..."
|
||||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||||
|
|
||||||
# for testing things
|
# for testing proxy.Process
|
||||||
simple-responder:
|
simple-responder:
|
||||||
@echo "Building simple responder"
|
@echo "Building simple responder"
|
||||||
go build -o $(BUILD_DIR)/simple-responder misc/simple-responder/simple-responder.go
|
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||||
|
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||||
|
|
||||||
# Ensure build directory exists
|
# Ensure build directory exists
|
||||||
$(BUILD_DIR):
|
$(BUILD_DIR):
|
||||||
|
|||||||
@@ -1,16 +1,23 @@
|
|||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, so let's swap llama-server instead!
|

|
||||||
|
|
||||||
llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically.
|
llama-swap is a golang server that automatically swaps the llama.cpp server on demand. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
||||||
|
|
||||||
- ✅ easy to deploy: single binary with no dependencies
|
Features:
|
||||||
- ✅ full control over llama-server's startup settings
|
|
||||||
- ✅ ❤️ for nvidia P40 users who are rely on llama.cpp for inference
|
- ✅ Easy to deploy: single binary with no dependencies
|
||||||
|
- ✅ Single yaml configuration file
|
||||||
|
- ✅ Automatically switching between models
|
||||||
|
- ✅ Full control over llama.cpp server settings per model
|
||||||
|
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
||||||
|
- ✅ Multiple GPU support
|
||||||
|
- ✅ Run multiple models at once with `profiles`
|
||||||
|
- ✅ Remote log monitoring at `/log`
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
llama-swap's configuration purposefully simple.
|
llama-swap's configuration is purposefully simple.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||||
@@ -20,29 +27,48 @@ healthCheckTimeout: 60
|
|||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
cmd: "llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf"
|
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||||
|
|
||||||
# Where to proxy to, important it matches this format
|
# where to reach the server started by cmd, make sure the ports match
|
||||||
proxy: "http://127.0.0.1:8999"
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# aliases model names to use this configuration for
|
# aliases names to use this model for
|
||||||
aliases:
|
aliases:
|
||||||
- "gpt-4o-mini"
|
- "gpt-4o-mini"
|
||||||
- "gpt-3.5-turbo"
|
- "gpt-3.5-turbo"
|
||||||
|
|
||||||
# wait for this path to return an HTTP 200 before serving requests
|
# check this path for an HTTP 200 OK before serving requests
|
||||||
# defaults to /health to match llama.cpp
|
# default: /health to match llama.cpp
|
||||||
#
|
# use "none" to skip endpoint checking, but may cause HTTP errors
|
||||||
# use "none" to skip endpoint checking. This may cause requests to fail
|
# until the model is ready
|
||||||
# until the server is ready
|
checkEndpoint: /custom-endpoint
|
||||||
checkEndpoint: "/custom-endpoint"
|
|
||||||
|
# automatically unload the model after this many seconds
|
||||||
|
# ttl values must be a value greater than 0
|
||||||
|
# default: 0 = never unload model
|
||||||
|
ttl: 60
|
||||||
|
|
||||||
"qwen":
|
"qwen":
|
||||||
# environment variables to pass to the command
|
# environment variables to pass to the command
|
||||||
env:
|
env:
|
||||||
- "CUDA_VISIBLE_DEVICES=0"
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
cmd: "llama-server --port 8999 -m path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf"
|
|
||||||
proxy: "http://127.0.0.1:8999"
|
# multiline for readability
|
||||||
|
cmd: >
|
||||||
|
llama-server --port 8999
|
||||||
|
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||||
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
|
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||||
|
#
|
||||||
|
# Tips:
|
||||||
|
# - each model must be listening on a unique address and port
|
||||||
|
# - the model name is in this format: "profile_name/model", like "coding/qwen"
|
||||||
|
# - the profile will load and unload all models in the profile at the same time
|
||||||
|
profiles:
|
||||||
|
coding:
|
||||||
|
- "qwen"
|
||||||
|
- "llama"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
@@ -52,6 +78,26 @@ models:
|
|||||||
* _Note: Windows currently untested._
|
* _Note: Windows currently untested._
|
||||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||||
|
|
||||||
|
## Monitoring Logs
|
||||||
|
|
||||||
|
The `/logs` endpoint is available to monitor what llama-swap is doing. It will send the last 10KB of logs. Useful for monitoring the output of llama-server. It also supports streaming of logs.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
```
|
||||||
|
# sends up to the last 10KB of logs
|
||||||
|
curl http://host/logs'
|
||||||
|
|
||||||
|
# streams logs using chunk encoding
|
||||||
|
curl -Ns 'http://host/logs/stream'
|
||||||
|
|
||||||
|
# skips history and just streams new log entries
|
||||||
|
curl -Ns 'http://host/logs/stream?no-history'
|
||||||
|
|
||||||
|
# streams logs using Server Sent Events
|
||||||
|
curl -Ns 'http://host/logs/streamSSE'
|
||||||
|
```
|
||||||
|
|
||||||
## Systemd Unit Files
|
## Systemd Unit Files
|
||||||
|
|
||||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||||
|
|||||||
+20
-8
@@ -1,14 +1,14 @@
|
|||||||
# Seconds to wait for llama.cpp to be available to serve requests
|
# Seconds to wait for llama.cpp to be available to serve requests
|
||||||
# Default (and minimum): 15 seconds
|
# Default (and minimum): 15 seconds
|
||||||
healthCheckTimeout: 60
|
healthCheckTimeout: 15
|
||||||
|
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
cmd: >
|
cmd: >
|
||||||
models/llama-server-osx
|
models/llama-server-osx
|
||||||
--port 8999
|
--port 9001
|
||||||
-m models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:9001
|
||||||
|
|
||||||
# list of model name aliases this llama.cpp instance can serve
|
# list of model name aliases this llama.cpp instance can serve
|
||||||
aliases:
|
aliases:
|
||||||
@@ -17,9 +17,12 @@ models:
|
|||||||
# check this path for a HTTP 200 response for the server to be ready
|
# check this path for a HTTP 200 response for the server to be ready
|
||||||
checkEndpoint: /health
|
checkEndpoint: /health
|
||||||
|
|
||||||
|
# unload model after 5 seconds
|
||||||
|
ttl: 5
|
||||||
|
|
||||||
"qwen":
|
"qwen":
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:9002
|
||||||
aliases:
|
aliases:
|
||||||
- gpt-3.5-turbo
|
- gpt-3.5-turbo
|
||||||
|
|
||||||
@@ -35,7 +38,16 @@ models:
|
|||||||
# until the upstream server is ready for traffic
|
# until the upstream server is ready for traffic
|
||||||
checkEndpoint: none
|
checkEndpoint: none
|
||||||
|
|
||||||
# don't use this, just for testing if things are broken
|
# don't use these, just for testing if things are broken
|
||||||
"broken":
|
"broken":
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
"broken_timeout":
|
||||||
|
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||||
|
proxy: http://127.0.0.1:9000
|
||||||
|
|
||||||
|
# creating a coding profile with models for code generation and general questions
|
||||||
|
profiles:
|
||||||
|
coding:
|
||||||
|
- "qwen"
|
||||||
|
- "llama"
|
||||||
@@ -8,6 +8,32 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
|
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
|
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||||
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||||
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
|
golang.org/x/crypto v0.23.0 // indirect
|
||||||
|
golang.org/x/net v0.25.0 // indirect
|
||||||
|
golang.org/x/sys v0.20.0 // indirect
|
||||||
|
golang.org/x/text v0.15.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,83 @@
|
|||||||
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
|
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||||
|
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||||
|
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||||
|
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||||
|
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
|
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||||
|
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||||
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
|
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
||||||
|
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||||
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||||
|
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||||
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
|
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
|
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||||
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
|
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||||
|
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||||
|
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||||
|
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||||
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||||
|
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||||
|
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
|
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||||
|
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||||
|
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||||
|
|||||||
BIN
Binary file not shown.
|
After Width: | Height: | Size: 261 KiB |
+9
-5
@@ -3,9 +3,9 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,12 +22,16 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyManager := proxy.New(config)
|
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||||
http.HandleFunc("/", proxyManager.HandleFunc)
|
gin.SetMode(mode)
|
||||||
|
} else {
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyManager := proxy.New(config)
|
||||||
fmt.Println("llama-swap listening on " + *listenStr)
|
fmt.Println("llama-swap listening on " + *listenStr)
|
||||||
if err := http.ListenAndServe(*listenStr, nil); err != nil {
|
if err := proxyManager.Run(*listenStr); err != nil {
|
||||||
fmt.Printf("Error starting server: %v\n", err)
|
fmt.Printf("Server error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,11 +16,19 @@ func main() {
|
|||||||
|
|
||||||
flag.Parse() // Parse the command-line flags
|
flag.Parse() // Parse the command-line flags
|
||||||
|
|
||||||
// Set up the handler function using the provided response message
|
responseMessageHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Set the header to text/plain
|
// Set the header to text/plain
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
fmt.Fprintln(w, *responseMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the handler function using the provided response message
|
||||||
|
http.HandleFunc("/v1/chat/completions", responseMessageHandler)
|
||||||
|
http.HandleFunc("/v1/completions", responseMessageHandler)
|
||||||
|
http.HandleFunc("/test", responseMessageHandler)
|
||||||
|
|
||||||
|
http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
fmt.Fprintln(w, *responseMessage)
|
fmt.Fprintln(w, *responseMessage)
|
||||||
|
|
||||||
// Get environment variables
|
// Get environment variables
|
||||||
@@ -39,7 +47,12 @@ func main() {
|
|||||||
w.Write([]byte(response))
|
w.Write([]byte(response))
|
||||||
})
|
})
|
||||||
|
|
||||||
address := ":" + *port // Address with the specified port
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
||||||
|
})
|
||||||
|
|
||||||
|
address := "127.0.0.1:" + *port // Address with the specified port
|
||||||
fmt.Printf("Server is listening on port %s\n", *port)
|
fmt.Printf("Server is listening on port %s\n", *port)
|
||||||
|
|
||||||
// Start the server and log any error if it occurs
|
// Start the server and log any error if it occurs
|
||||||
|
|||||||
+27
-14
@@ -14,6 +14,7 @@ type ModelConfig struct {
|
|||||||
Aliases []string `yaml:"aliases"`
|
Aliases []string `yaml:"aliases"`
|
||||||
Env []string `yaml:"env"`
|
Env []string `yaml:"env"`
|
||||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||||
|
UnloadAfter int `yaml:"ttl"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
@@ -21,26 +22,30 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Models map[string]ModelConfig `yaml:"models"`
|
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
|
Models map[string]ModelConfig `yaml:"models"`
|
||||||
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
|
||||||
|
// map aliases to actual model IDs
|
||||||
|
aliases map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) FindConfig(modelName string) (ModelConfig, bool) {
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
modelConfig, found := c.Models[modelName]
|
if _, found := c.Models[search]; found {
|
||||||
if found {
|
return search, true
|
||||||
return modelConfig, true
|
} else if name, found := c.aliases[search]; found {
|
||||||
|
return name, found
|
||||||
|
} else {
|
||||||
|
return "", false
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Search through aliases to find the right config
|
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||||
for _, config := range c.Models {
|
if realName, found := c.RealModelName(modelName); !found {
|
||||||
for _, alias := range config.Aliases {
|
return ModelConfig{}, "", false
|
||||||
if alias == modelName {
|
} else {
|
||||||
return config, true
|
return c.Models[realName], realName, true
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ModelConfig{}, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig(path string) (*Config, error) {
|
func LoadConfig(path string) (*Config, error) {
|
||||||
@@ -59,6 +64,14 @@ func LoadConfig(path string) (*Config, error) {
|
|||||||
config.HealthCheckTimeout = 15
|
config.HealthCheckTimeout = 15
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Populate the aliases map
|
||||||
|
config.aliases = make(map[string]string)
|
||||||
|
for modelName, modelConfig := range config.Models {
|
||||||
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
config.aliases[alias] = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+49
-8
@@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLoadConfig(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
// Create a temporary YAML file for testing
|
// Create a temporary YAML file for testing
|
||||||
tempDir, err := os.MkdirTemp("", "test-config")
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -17,7 +17,8 @@ func TestLoadConfig(t *testing.T) {
|
|||||||
defer os.RemoveAll(tempDir)
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
content := `models:
|
content := `
|
||||||
|
models:
|
||||||
model1:
|
model1:
|
||||||
cmd: path/to/cmd --arg1 one
|
cmd: path/to/cmd --arg1 one
|
||||||
proxy: "http://localhost:8080"
|
proxy: "http://localhost:8080"
|
||||||
@@ -28,7 +29,17 @@ func TestLoadConfig(t *testing.T) {
|
|||||||
- "VAR1=value1"
|
- "VAR1=value1"
|
||||||
- "VAR2=value2"
|
- "VAR2=value2"
|
||||||
checkEndpoint: "/health"
|
checkEndpoint: "/health"
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "m2"
|
||||||
|
checkEndpoint: "/"
|
||||||
healthCheckTimeout: 15
|
healthCheckTimeout: 15
|
||||||
|
profiles:
|
||||||
|
test:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
`
|
`
|
||||||
|
|
||||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
@@ -50,14 +61,33 @@ healthCheckTimeout: 15
|
|||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
},
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2"},
|
||||||
|
Env: nil,
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expected, config)
|
assert.Equal(t, expected, config)
|
||||||
|
|
||||||
|
realname, found := config.RealModelName("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", realname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelConfigSanitizedCommand(t *testing.T) {
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
config := &ModelConfig{
|
config := &ModelConfig{
|
||||||
Cmd: `python model1.py \
|
Cmd: `python model1.py \
|
||||||
--arg1 value1 \
|
--arg1 value1 \
|
||||||
@@ -69,7 +99,10 @@ func TestModelConfigSanitizedCommand(t *testing.T) {
|
|||||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindConfig(t *testing.T) {
|
func TestConfig_FindConfig(t *testing.T) {
|
||||||
|
|
||||||
|
// TODO?
|
||||||
|
// make make this shared between the different tests
|
||||||
config := &Config{
|
config := &Config{
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
@@ -88,25 +121,33 @@ func TestFindConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 10,
|
HealthCheckTimeout: 10,
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test finding a model by its name
|
// Test finding a model by its name
|
||||||
modelConfig, found := config.FindConfig("model1")
|
modelConfig, modelId, found := config.FindConfig("model1")
|
||||||
assert.True(t, found)
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", modelId)
|
||||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||||
|
|
||||||
// Test finding a model by its alias
|
// Test finding a model by its alias
|
||||||
modelConfig, found = config.FindConfig("m1")
|
modelConfig, modelId, found = config.FindConfig("m1")
|
||||||
assert.True(t, found)
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", modelId)
|
||||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||||
|
|
||||||
// Test finding a model that does not exist
|
// Test finding a model that does not exist
|
||||||
modelConfig, found = config.FindConfig("model3")
|
modelConfig, modelId, found = config.FindConfig("model3")
|
||||||
assert.False(t, found)
|
assert.False(t, found)
|
||||||
|
assert.Equal(t, "", modelId)
|
||||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSanitizeCommand(t *testing.T) {
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
// Test a simple command
|
// Test a simple command
|
||||||
args, err := SanitizeCommand("python model1.py")
|
args, err := SanitizeCommand("python model1.py")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
nextTestPort int = 12000
|
||||||
|
portMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check if the binary exists
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
binaryPath := getSimpleResponderPath()
|
||||||
|
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
|
||||||
|
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to get the binary path
|
||||||
|
func getSimpleResponderPath() string {
|
||||||
|
goos := runtime.GOOS
|
||||||
|
goarch := runtime.GOARCH
|
||||||
|
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||||
|
portMutex.Lock()
|
||||||
|
defer portMutex.Unlock()
|
||||||
|
|
||||||
|
port := nextTestPort
|
||||||
|
nextTestPort++
|
||||||
|
|
||||||
|
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||||
|
binaryPath := getSimpleResponderPath()
|
||||||
|
|
||||||
|
// Create a process configuration
|
||||||
|
return ModelConfig{
|
||||||
|
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, port, expectedMessage),
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/ring"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LogMonitor struct {
|
||||||
|
clients map[chan []byte]bool
|
||||||
|
mu sync.RWMutex
|
||||||
|
buffer *ring.Ring
|
||||||
|
bufferMu sync.RWMutex
|
||||||
|
|
||||||
|
// typically this can be os.Stdout
|
||||||
|
stdout io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogMonitor() *LogMonitor {
|
||||||
|
return NewLogMonitorWriter(os.Stdout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||||
|
return &LogMonitor{
|
||||||
|
clients: make(map[chan []byte]bool),
|
||||||
|
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||||
|
stdout: stdout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = w.stdout.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.bufferMu.Lock()
|
||||||
|
bufferCopy := make([]byte, len(p))
|
||||||
|
copy(bufferCopy, p)
|
||||||
|
w.buffer.Value = bufferCopy
|
||||||
|
w.buffer = w.buffer.Next()
|
||||||
|
w.bufferMu.Unlock()
|
||||||
|
|
||||||
|
w.broadcast(p)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) GetHistory() []byte {
|
||||||
|
w.bufferMu.RLock()
|
||||||
|
defer w.bufferMu.RUnlock()
|
||||||
|
|
||||||
|
var history []byte
|
||||||
|
w.buffer.Do(func(p any) {
|
||||||
|
if p != nil {
|
||||||
|
if content, ok := p.([]byte); ok {
|
||||||
|
history = append(history, content...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return history
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Subscribe() chan []byte {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
ch := make(chan []byte, 100)
|
||||||
|
w.clients[ch] = true
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
delete(w.clients, ch)
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) broadcast(msg []byte) {
|
||||||
|
w.mu.RLock()
|
||||||
|
defer w.mu.RUnlock()
|
||||||
|
|
||||||
|
for client := range w.clients {
|
||||||
|
select {
|
||||||
|
case client <- msg:
|
||||||
|
default:
|
||||||
|
// If client buffer is full, skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogMonitor(t *testing.T) {
|
||||||
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Test subscription
|
||||||
|
client1 := logMonitor.Subscribe()
|
||||||
|
client2 := logMonitor.Subscribe()
|
||||||
|
|
||||||
|
defer logMonitor.Unsubscribe(client1)
|
||||||
|
defer logMonitor.Unsubscribe(client2)
|
||||||
|
|
||||||
|
client1Messages := make([]byte, 0)
|
||||||
|
client2Messages := make([]byte, 0)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case data := <-client1:
|
||||||
|
client1Messages = append(client1Messages, data...)
|
||||||
|
case data := <-client2:
|
||||||
|
client2Messages = append(client2Messages, data...)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
logMonitor.Write([]byte("1"))
|
||||||
|
logMonitor.Write([]byte("2"))
|
||||||
|
logMonitor.Write([]byte("3"))
|
||||||
|
|
||||||
|
// Wait for the goroutine to finish
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Check the buffer
|
||||||
|
expectedHistory := "123"
|
||||||
|
history := string(logMonitor.GetHistory())
|
||||||
|
|
||||||
|
if history != expectedHistory {
|
||||||
|
t.Errorf("Expected history: %s, got: %s", expectedHistory, history)
|
||||||
|
}
|
||||||
|
|
||||||
|
c1Data := string(client1Messages)
|
||||||
|
if c1Data != expectedHistory {
|
||||||
|
t.Errorf("Client1 expected %s, got: %s", expectedHistory, c1Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
c2Data := string(client2Messages)
|
||||||
|
if c2Data != expectedHistory {
|
||||||
|
t.Errorf("Client2 expected %s, got: %s", expectedHistory, c2Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrite_ImmutableBuffer(t *testing.T) {
|
||||||
|
// Create a new LogMonitor instance
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Prepare a message to write
|
||||||
|
msg := []byte("Hello, World!")
|
||||||
|
lenmsg := len(msg)
|
||||||
|
|
||||||
|
// Write the message to the LogMonitor
|
||||||
|
n, err := lm.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Write failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != lenmsg {
|
||||||
|
t.Errorf("Expected %d bytes written but got %d", lenmsg, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the original message
|
||||||
|
msg[0] = 'B' // This should not affect the buffer
|
||||||
|
|
||||||
|
// Get the history from the LogMonitor
|
||||||
|
history := lm.GetHistory()
|
||||||
|
|
||||||
|
// Check that the history contains the original message, not the modified one
|
||||||
|
expected := []byte("Hello, World!")
|
||||||
|
if !bytes.Equal(history, expected) {
|
||||||
|
t.Errorf("Expected history to be %q, got %q", expected, history)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ProxyManager struct {
|
|
||||||
sync.Mutex
|
|
||||||
|
|
||||||
config *Config
|
|
||||||
currentCmd *exec.Cmd
|
|
||||||
currentConfig ModelConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(config *Config) *ProxyManager {
|
|
||||||
return &ProxyManager{config: config}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) {
|
|
||||||
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints
|
|
||||||
|
|
||||||
if r.URL.Path == "/v1/chat/completions" {
|
|
||||||
// extracts the `model` from json body
|
|
||||||
pm.proxyChatRequest(w, r)
|
|
||||||
} else {
|
|
||||||
pm.proxyRequest(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) swapModel(requestedModel string) error {
|
|
||||||
pm.Lock()
|
|
||||||
defer pm.Unlock()
|
|
||||||
|
|
||||||
// find the model configuration matching requestedModel
|
|
||||||
modelConfig, found := pm.config.FindConfig(requestedModel)
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("could not find configuration for %s", requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// no need to swap llama.cpp instances
|
|
||||||
if pm.currentConfig.Cmd == modelConfig.Cmd {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// kill the current running one to swap it
|
|
||||||
if pm.currentCmd != nil {
|
|
||||||
pm.currentCmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
|
|
||||||
// wait for it to end
|
|
||||||
pm.currentCmd.Process.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.currentConfig = modelConfig
|
|
||||||
|
|
||||||
args, err := modelConfig.SanitizedCommand()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
|
||||||
}
|
|
||||||
cmd := exec.Command(args[0], args[1:]...)
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
cmd.Env = modelConfig.Env
|
|
||||||
|
|
||||||
err = cmd.Start()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
pm.currentCmd = cmd
|
|
||||||
|
|
||||||
if err := pm.checkHealthEndpoint(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) checkHealthEndpoint() error {
|
|
||||||
|
|
||||||
if pm.currentConfig.Proxy == "" {
|
|
||||||
return fmt.Errorf("no upstream available to check /health")
|
|
||||||
}
|
|
||||||
|
|
||||||
checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint)
|
|
||||||
|
|
||||||
if checkEndpoint == "none" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// keep default behaviour
|
|
||||||
if checkEndpoint == "" {
|
|
||||||
checkEndpoint = "/health"
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyTo := pm.currentConfig.Proxy
|
|
||||||
maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout)
|
|
||||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
|
||||||
}
|
|
||||||
client := &http.Client{}
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
for {
|
|
||||||
req, err := http.NewRequest("GET", healthURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(req.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
|
||||||
|
|
||||||
// if TCP dial can't connect any HTTP response after 5 seconds
|
|
||||||
// exit quickly.
|
|
||||||
if time.Since(startTime) > 5*time.Second {
|
|
||||||
return fmt.Errorf("health check endpoint took more than 5 seconds to respond")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Since(startTime) >= maxDuration {
|
|
||||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if time.Since(startTime) >= maxDuration {
|
|
||||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) {
|
|
||||||
bodyBytes, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var requestBody map[string]interface{}
|
|
||||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
|
||||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
model, ok := requestBody["model"].(string)
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "Missing or invalid 'model' key", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := pm.swapModel(model); err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("unable to swap to model: %s", err.Error()), http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
pm.proxyRequest(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if pm.currentConfig.Proxy == "" {
|
|
||||||
http.Error(w, "No upstream proxy", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyTo := pm.currentConfig.Proxy
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header = r.Header
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
for k, vv := range resp.Header {
|
|
||||||
for _, v := range vv {
|
|
||||||
w.Header().Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
// faster than io.Copy when streaming
|
|
||||||
buf := make([]byte, 32*1024)
|
|
||||||
for {
|
|
||||||
n, err := resp.Body.Read(buf)
|
|
||||||
if n > 0 {
|
|
||||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
|
||||||
http.Error(w, writeErr.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if flusher, ok := w.(http.Flusher); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,259 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Process struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
ID string
|
||||||
|
config ModelConfig
|
||||||
|
cmd *exec.Cmd
|
||||||
|
logMonitor *LogMonitor
|
||||||
|
healthCheckTimeout int
|
||||||
|
|
||||||
|
isRunning bool
|
||||||
|
lastRequestHandled time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||||
|
return &Process{
|
||||||
|
ID: ID,
|
||||||
|
config: config,
|
||||||
|
cmd: nil,
|
||||||
|
logMonitor: logMonitor,
|
||||||
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// start the process and check it for errors
|
||||||
|
func (p *Process) start() error {
|
||||||
|
p.Lock()
|
||||||
|
defer p.Unlock()
|
||||||
|
|
||||||
|
if p.isRunning {
|
||||||
|
return fmt.Errorf("process already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
args, err := p.config.SanitizedCommand()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.cmd = exec.Command(args[0], args[1:]...)
|
||||||
|
p.cmd.Stdout = p.logMonitor
|
||||||
|
p.cmd.Stderr = p.logMonitor
|
||||||
|
p.cmd.Env = p.config.Env
|
||||||
|
|
||||||
|
err = p.cmd.Start()
|
||||||
|
p.isRunning = true
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// watch for the command to exit
|
||||||
|
cmdCtx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
|
||||||
|
// monitor the command's exit status. Usually this happens if
|
||||||
|
// the process exited unexpectedly
|
||||||
|
go func() {
|
||||||
|
err := p.cmd.Wait()
|
||||||
|
if err != nil {
|
||||||
|
cancel(fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error()))
|
||||||
|
} else {
|
||||||
|
cancel(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.isRunning = false
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait a bit for process to start before checking the health endpoint
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
|
||||||
|
// wait for checkHealthEndpoint
|
||||||
|
if err := p.checkHealthEndpoint(cmdCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.UnloadAfter > 0 {
|
||||||
|
// start a goroutine to check every second if
|
||||||
|
// the process should be stopped
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||||
|
|
||||||
|
for {
|
||||||
|
<-ticker.C
|
||||||
|
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter)
|
||||||
|
p.Stop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Process) Stop() {
|
||||||
|
p.Lock()
|
||||||
|
defer p.Unlock()
|
||||||
|
|
||||||
|
if !p.isRunning || p.cmd == nil || p.cmd.Process == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
|
p.cmd.Process.Wait()
|
||||||
|
p.isRunning = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Process) IsRunning() bool {
|
||||||
|
return p.isRunning
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
|
||||||
|
if p.config.Proxy == "" {
|
||||||
|
return fmt.Errorf("no upstream available to check /health")
|
||||||
|
}
|
||||||
|
|
||||||
|
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||||
|
|
||||||
|
if checkEndpoint == "none" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// keep default behaviour
|
||||||
|
if checkEndpoint == "" {
|
||||||
|
checkEndpoint = "/health"
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyTo := p.config.Proxy
|
||||||
|
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||||
|
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
for {
|
||||||
|
req, err := http.NewRequest("GET", healthURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(cmdCtx, time.Second)
|
||||||
|
defer cancel()
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
|
||||||
|
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// check if the context was cancelled
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err := context.Cause(ctx)
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait a bit longer for TCP connection issues
|
||||||
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
|
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
} else {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttl < 0 {
|
||||||
|
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttl < 0 {
|
||||||
|
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !p.isRunning {
|
||||||
|
if err := p.start(); err != nil {
|
||||||
|
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||||
|
http.Error(w, errstr, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.lastRequestHandled = time.Now()
|
||||||
|
|
||||||
|
proxyTo := p.config.Proxy
|
||||||
|
client := &http.Client{}
|
||||||
|
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header = r.Header.Clone()
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
for k, vv := range resp.Header {
|
||||||
|
for _, v := range vv {
|
||||||
|
w.Header().Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
// faster than io.Copy when streaming
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
for {
|
||||||
|
n, err := resp.Body.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if flusher, ok := w.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||||
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
expectedMessage := "testing91931"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
// Create a process
|
||||||
|
process := NewProcess("test-process", 5, config, logMonitor)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// process is automatically started
|
||||||
|
assert.False(t, process.IsRunning())
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.True(t, process.IsRunning())
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||||
|
|
||||||
|
// Stop the process
|
||||||
|
process.Stop()
|
||||||
|
|
||||||
|
req = httptest.NewRequest("GET", "/", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Proxy the request
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
|
||||||
|
// should have automatically started the process again
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// test that the automatic start returns the expected error type
|
||||||
|
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||||
|
// Create a process configuration
|
||||||
|
config := ModelConfig{
|
||||||
|
Cmd: "nonexistant-command",
|
||||||
|
Proxy: "http://127.0.0.1:9913",
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||||
|
}
|
||||||
|
|
||||||
|
// test that the process unloads after the TTL
|
||||||
|
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping long auto unload TTL test")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := "I_sense_imminent_danger"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
assert.Equal(t, 0, config.UnloadAfter)
|
||||||
|
config.UnloadAfter = 3 // seconds
|
||||||
|
assert.Equal(t, 3, config.UnloadAfter)
|
||||||
|
|
||||||
|
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Proxy the request (auto start)
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||||
|
|
||||||
|
assert.True(t, process.IsRunning())
|
||||||
|
|
||||||
|
// wait 5 seconds
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
assert.False(t, process.IsRunning())
|
||||||
|
}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProxyManager struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
config *Config
|
||||||
|
currentProcesses map[string]*Process
|
||||||
|
logMonitor *LogMonitor
|
||||||
|
ginEngine *gin.Engine
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(config *Config) *ProxyManager {
|
||||||
|
pm := &ProxyManager{
|
||||||
|
config: config,
|
||||||
|
currentProcesses: make(map[string]*Process),
|
||||||
|
logMonitor: NewLogMonitor(),
|
||||||
|
ginEngine: gin.New(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up routes using the Gin engine
|
||||||
|
pm.ginEngine.POST("/v1/chat/completions", pm.proxyChatRequestHandler)
|
||||||
|
|
||||||
|
// Support legacy /v1/completions api, see issue #12
|
||||||
|
pm.ginEngine.POST("/v1/completions", pm.proxyChatRequestHandler)
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||||
|
|
||||||
|
// in proxymanager_loghandlers.go
|
||||||
|
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||||
|
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||||
|
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||||
|
|
||||||
|
pm.ginEngine.NoRoute(pm.proxyNoRouteHandler)
|
||||||
|
|
||||||
|
// Disable console color for testing
|
||||||
|
gin.DisableConsoleColor()
|
||||||
|
|
||||||
|
return pm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) Run(addr ...string) error {
|
||||||
|
return pm.ginEngine.Run(addr...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||||
|
pm.ginEngine.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) StopProcesses() {
|
||||||
|
pm.Lock()
|
||||||
|
defer pm.Unlock()
|
||||||
|
|
||||||
|
pm.stopProcesses()
|
||||||
|
}
|
||||||
|
|
||||||
|
// for internal usage
|
||||||
|
func (pm *ProxyManager) stopProcesses() {
|
||||||
|
for _, process := range pm.currentProcesses {
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
|
pm.currentProcesses = make(map[string]*Process)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
|
data := []interface{}{}
|
||||||
|
for id := range pm.config.Models {
|
||||||
|
data = append(data, map[string]interface{}{
|
||||||
|
"id": id,
|
||||||
|
"object": "model",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"owned_by": "llama-swap",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the Content-Type header to application/json
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Encode the data as JSON and write it to the response writer
|
||||||
|
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||||
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||||
|
pm.Lock()
|
||||||
|
defer pm.Unlock()
|
||||||
|
|
||||||
|
// Check if requestedModel contains a /
|
||||||
|
groupName, modelName := "", requestedModel
|
||||||
|
if idx := strings.Index(requestedModel, "/"); idx != -1 {
|
||||||
|
groupName = requestedModel[:idx]
|
||||||
|
modelName = requestedModel[idx+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if groupName != "" {
|
||||||
|
if _, found := pm.config.Profiles[groupName]; !found {
|
||||||
|
return nil, fmt.Errorf("model group not found %s", groupName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// de-alias the real model name and get a real one
|
||||||
|
realModelName, found := pm.config.RealModelName(modelName)
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exit early when already running, otherwise stop everything and swap
|
||||||
|
requestedProcessKey := groupName + "/" + realModelName
|
||||||
|
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||||
|
return process, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop all running models
|
||||||
|
pm.stopProcesses()
|
||||||
|
|
||||||
|
if groupName == "" {
|
||||||
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
|
processKey := groupName + "/" + modelID
|
||||||
|
pm.currentProcesses[processKey] = process
|
||||||
|
} else {
|
||||||
|
for _, modelName := range pm.config.Profiles[groupName] {
|
||||||
|
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||||
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName)
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
|
processKey := groupName + "/" + modelID
|
||||||
|
pm.currentProcesses[processKey] = process
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestedProcessKey should exist due to swap
|
||||||
|
return pm.currentProcesses[requestedProcessKey], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var requestBody map[string]interface{}
|
||||||
|
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
||||||
|
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model, ok := requestBody["model"].(string)
|
||||||
|
if !ok {
|
||||||
|
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("missing or invalid 'model' key"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if process, err := pm.swapModel(model); err != nil {
|
||||||
|
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
|
c.Request.Header.Del("transfer-encoding")
|
||||||
|
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
|
||||||
|
process.ProxyRequest(c.Writer, c.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
||||||
|
// since maps are unordered, just use the first available process if one exists
|
||||||
|
for _, process := range pm.currentProcesses {
|
||||||
|
process.ProxyRequest(c.Writer, c.Request)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/plain")
|
||||||
|
history := pm.logMonitor.GetHistory()
|
||||||
|
_, err := c.Writer.Write(history)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(http.StatusInternalServerError, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/plain")
|
||||||
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
|
ch := pm.logMonitor.Subscribe()
|
||||||
|
defer pm.logMonitor.Unsubscribe(ch)
|
||||||
|
|
||||||
|
notify := c.Request.Context().Done()
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("Streaming unsupported"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, skipHistory := c.GetQuery("no-history")
|
||||||
|
// Send history first if not skipped
|
||||||
|
|
||||||
|
if !skipHistory {
|
||||||
|
history := pm.logMonitor.GetHistory()
|
||||||
|
if len(history) != 0 {
|
||||||
|
_, err := c.Writer.Write(history)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(http.StatusInternalServerError, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream new logs
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-ch:
|
||||||
|
_, err := c.Writer.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(http.StatusInternalServerError, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
case <-notify:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
|
ch := pm.logMonitor.Subscribe()
|
||||||
|
defer pm.logMonitor.Unsubscribe(ch)
|
||||||
|
|
||||||
|
notify := c.Request.Context().Done()
|
||||||
|
|
||||||
|
// Send history first if not skipped
|
||||||
|
_, skipHistory := c.GetQuery("no-history")
|
||||||
|
if !skipHistory {
|
||||||
|
history := pm.logMonitor.GetHistory()
|
||||||
|
if len(history) != 0 {
|
||||||
|
c.SSEvent("message", string(history))
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream new logs
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-ch:
|
||||||
|
c.SSEvent("message", string(msg))
|
||||||
|
c.Writer.Flush()
|
||||||
|
case <-notify:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
|
||||||
|
_, exists := proxy.currentProcesses["/"+modelName]
|
||||||
|
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure there's only one loaded model
|
||||||
|
assert.Len(t, proxy.currentProcesses, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
for modelID, requestedModel := range map[string]string{"model1": "test/model1", "model2": "test/model2"} {
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure there's two loaded models
|
||||||
|
assert.Len(t, proxy.currentProcesses, 2)
|
||||||
|
_, exists := proxy.currentProcesses["test/model1"]
|
||||||
|
assert.True(t, exists, "expected test/model1 key in currentProcesses")
|
||||||
|
|
||||||
|
_, exists = proxy.currentProcesses["test/model2"]
|
||||||
|
assert.True(t, exists, "expected test/model2 key in currentProcesses")
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user