Compare commits

..

55 Commits

Author SHA1 Message Date
Benson Wong e7af671d8e remove noisy debug print message 2025-05-19 15:36:15 -07:00
Benson Wong 8e62098eef add guard to avoid unnecessary logic in Process.Shutdown 2025-05-19 15:34:30 -07:00
choyuansu c260907415 Add linux install and uninstall shell scripts (#139)
Contribution for install, and uninstall llama-swap in linux.
2025-05-19 12:03:33 -07:00
Benson Wong b83a5fa291 make Failed stated recoverable (#137)
A process in the failed state can transition to stopped either by calling /unload or swapping to another model.
2025-05-16 19:54:44 -07:00
Benson Wong 6e2ff28d59 improve cmdStop docs [no ci] 2025-05-16 13:52:04 -07:00
Benson Wong a8b81f2799 Add stopCmd for custom stopping instructions (#136)
Allow configuration of how a model is stopped before swapping. Setting `cmdStop` in the configuration will override the default behaviour and enables better integration with other process/container managers like docker or podman.
2025-05-16 13:48:42 -07:00
Benson Wong f9ee7156dc update configuration examples for multiline yaml commands #133 2025-05-16 11:45:39 -07:00
fakezeta 2d00120781 Update proxymanager.go (#135) 2025-05-16 06:45:09 -07:00
Benson Wong afc9aef058 Fix #133 SanitizeCommand removes comments (#134) 2025-05-15 15:28:50 -07:00
Benson Wong d7b390df74 Add GH Action for Testing on Windows (#132)
* Add windows specific test changes
* Change the command line parsing library - Possible breaking changes for windows users!
2025-05-14 21:51:53 -07:00
Benson Wong 5025c2f1f3 Add GH windows tests (not working yet) 2025-05-14 19:58:22 -07:00
Benson Wong e3a0b013c1 add content length test for #131 2025-05-14 19:50:01 -07:00
Fadenfire f5763a94a0 Fix content length being incorrect when useModelName is used (#131)
* Fix content length being incorrect when useModelName is used
* Update c.Request.ContentLength as well
2025-05-14 19:37:54 -07:00
Benson Wong 8ada72eb57 Update issue templates 2025-05-14 16:36:32 -07:00
Benson Wong 2441b383d3 Make checking for process killed status more robust 2025-05-14 16:26:56 -07:00
Benson Wong 25f251699c Prevent StateFailed after SIGKILL (#129)
Closes #125
2025-05-14 10:47:35 -07:00
Benson Wong 7f37bcc6eb Improve testing around using SIGKILL (#127)
* Add test for SIGKILL of process
* silent TestProxyManager_RunningEndpoint debug output
* Ref #125
2025-05-13 21:21:52 -07:00
Benson Wong 519c3a4d22 Change /unload to not wait for inflight requests (#125)
Sometimes upstreams can accept HTTP but never respond causing requests
to build up waiting for a response. This can block Process.Stop() as
that waits for inflight requests to finish. This change refactors the
code to not wait when attempting to shutdown the process.
2025-05-13 11:39:19 -07:00
Benson Wong 9dc4bcb46c Add a concurrency limit to Process.ProxyRequest (#123) 2025-05-12 18:12:52 -07:00
Benson Wong cb876c143b update example config 2025-05-12 10:20:18 -07:00
Sam bc652709a5 Add config hot-reload (#106)
introduce --watch-config command line option to reload ProxyManager when configuration changes.
2025-05-11 17:37:00 -07:00
Thammachart Chinvarapon 9548931258 ci: re-enabled intel build pipeline (#121) 2025-05-11 00:19:57 -07:00
Benson Wong 5c5a5da664 Update README.md
removed extra section.
2025-05-06 06:59:15 -07:00
Benson Wong aa9ef59aa5 Create .coderabbit.yaml 2025-05-05 19:47:23 -07:00
Benson Wong 09e52c0500 Automatic Port Numbers (#105)
Add automatic port numbers assignment in configuration file. The string `${PORT}` will be substituted in model.cmd and model.proxy for an actual port number. This also allows model.proxy to be omitted from the configuration.
2025-05-05 17:07:43 -07:00
Benson Wong ca9063ffbe ensure aliases are unique (#116) 2025-05-05 15:34:18 -07:00
Benson Wong 21d7973d11 Improve content-length handling (#115)
ref: See #114

* Improve content-length handling
- Content length was not always being sent
- Add tests for content-length
2025-05-05 10:46:26 -07:00
Yi Hong Ang cc450e9c5f fix issue where proxy is still proxying with chunked transfer-encoding (#114) 2025-05-05 10:00:03 -07:00
Benson Wong 27465fe053 bug fix with missing early return statements fix #112 2025-05-05 09:32:44 -07:00
Benson Wong 9667989727 Disabling intel container build since it's been broken for weeks. 2025-05-04 21:39:42 -07:00
Benson Wong d9a1ddea0d Truncate web logs to 100K characters (#111)
* set log limit to 100K in browser
2025-05-02 23:43:21 -07:00
Benson Wong e7ab024ca0 small locking optimization 2025-05-02 23:18:07 -07:00
Benson Wong 448ccae959 Introduce Groups Feature (#107)
Groups allows more control over swapping behaviour when a model is requested. The new groups feature provides three ways to control swapping: within the group, swapping out other groups or keep the models in the group loaded persistently (never swapped out). 

Closes #96, #99 and #106.
2025-05-02 22:35:38 -07:00
Benson Wong ec0348e431 Reduce stale time for issues 2025-04-29 21:16:34 -07:00
Benson Wong 06eda7f591 tag all process logs with its ID (#103)
Makes identifying Process of log messages easier
2025-04-25 12:58:25 -07:00
Benson Wong 5fad24c16f Make checkHealthTimeout Interruptable during startup (#102)
interrupt and exit Process.start() early if the upstream process exits prematurely or unexpectedly.
2025-04-24 14:39:33 -07:00
Benson Wong 8404244fab Moderate security update for golang/x/net -> v0.38.0 2025-04-24 09:58:40 -07:00
Benson Wong 712cd01081 fix confusing INFO message [no ci] 2025-04-24 09:56:20 -07:00
Benson Wong 1f7aa359b1 Update header image
AI has finally made my dreams of llamas in funny clothing and stuck in
a claw machine waiting to be picked come true!
2025-04-23 13:02:12 -07:00
Benson Wong b138d6cf25 fix starhistory in README 2025-04-15 20:23:46 -07:00
Benson Wong fb7c808082 add timing for Process start, stop, total request time (#91) 2025-04-14 14:34:59 -07:00
Benson Wong a7e640b0f7 add aider example 2025-04-07 12:37:14 -07:00
Benson Wong 593604dfdc Show proxy and upstream logs in separate columns in logs UI 2025-04-05 10:36:54 -07:00
Benson Wong b8f888f864 Logging Improvements (#88)
This change revamps the internal logging architecture to be more flexible and descriptive. Previously all logs from both llama-swap and upstream services were mixed together. This makes it harder to troubleshoot and identify problems. This PR adds these new endpoints: 

- `/logs/stream/proxy` - just llama-swap's logs
- `/logs/stream/upstream` - stdout output from the upstream server
2025-04-04 21:01:33 -07:00
Benson Wong 192b2ae621 Remove no longer needed test 2025-04-04 14:46:01 -07:00
Benson Wong b7f8cb5094 Limit Access-Control-Allow-Origin to OPTIONS preflight requests #85 2025-04-04 14:44:35 -07:00
Benson Wong a23da6eb57 Sanitize CORS headers (#85)
Add sanitation step for `Access-Control-Allow-Headers` when echoing back user supplied headers
2025-04-01 08:43:53 -07:00
Grigorii Khvatskii 4c3aa40564 add graceful process termination on windows (#82) 2025-03-25 15:26:33 -07:00
Benson Wong 84e2c07a7e Refactor wildcard out of CORS headers (#81)
Changes to CORS functionality: 

- `Access-Control-Allow-Origin: *` is set for all requests 
- for pre-flight OPTIONS requests
  - specify methods: `Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS`
  - if the client sent `Access-Control-Request-Headers` then echo back the same value in `Access-Control-Allow-Headers`. If no `Access-Control-Request-Headers` were sent, then send back a default set
  - set `Access-Control-Max-Age: 86400` to that may improve performance 
- Add CORS tests to the proxy-manager
2025-03-25 15:24:43 -07:00
Benson Wong 680af28bcc Allow very permissive CORS headers (#77) 2025-03-20 15:50:21 -07:00
Benson Wong d94db42ffe fix bug checking incorrect error 2025-03-20 15:49:36 -07:00
Benson Wong 93cd83c55c add override for windows (#76) 2025-03-20 13:23:04 -07:00
Benson Wong 5565fca3ac add some badges to README 2025-03-19 11:25:06 -07:00
Benson Wong d625ab8d92 Refactor process state management (#70) (#73)
* add isValidStateTransition helper function
* Replace Process.setState() with Process.swapState()
* Refactor locking logic in Process
2025-03-15 17:14:03 -07:00
Benson Wong a3f82c140b tidy up config examples in README 2025-03-15 10:36:45 -07:00
36 changed files with 3136 additions and 956 deletions
+15
View File
@@ -0,0 +1,15 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
poem: false
review_status: true
collapse_walkthrough: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true
+37
View File
@@ -0,0 +1,37 @@
---
name: Bug Report
about: Something is not working as expected...
title: ''
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Expected behaviour**
A clear and concise description of what you expected to happen.
**Operating system and version**
- OS: (linux, osx, windows, freebsd, etc)
- GPUs: (list architecture)
**My Configuration**
```yaml
# copy / paste your configuration here
```
**Proxy Logs**
```
# copy / paste from /logs
```
**Upstream Logs**
```
# copy/paste from /logs
```
+4 -4
View File
@@ -13,11 +13,11 @@ jobs:
steps:
- uses: actions/stale@v9
with:
days-before-issue-stale: 30
days-before-issue-stale: 14
days-before-issue-close: 14
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}
repo-token: ${{ secrets.GITHUB_TOKEN }}
+50
View File
@@ -0,0 +1,50 @@
name: Windows CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
shell: bash
run: make simple-responder-windows
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
- name: Test all
shell: bash
run: make test-all
+18 -3
View File
@@ -1,6 +1,4 @@
# This workflow will build a golang project
name: CI
name: Linux CI
on:
push:
@@ -24,9 +22,26 @@ jobs:
with:
go-version: '1.23'
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
run: make simple-responder
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
- name: Test all
run: make test-all
+13 -1
View File
@@ -15,4 +15,16 @@ builds:
- goos: freebsd
goarch: arm64
- goos: windows
goarch: arm64
goarch: arm64
# use zip format for windows
archives:
- id: default
format: tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds_info:
group: root
owner: root
format_overrides:
- goos: windows
format: zip
+6 -2
View File
@@ -20,10 +20,10 @@ clean:
rm -rf $(BUILD_DIR)
test:
go test -short -v ./proxy
go test -short -v -count=1 ./proxy
test-all:
go test -v ./proxy
go test -v -count=1 ./proxy
# Build OSX binary
mac:
@@ -46,6 +46,10 @@ simple-responder:
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
simple-responder-windows:
@echo "Building simple responder for windows"
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
# Ensure build directory exists
$(BUILD_DIR):
mkdir -p $(BUILD_DIR)
+106 -51
View File
@@ -1,4 +1,7 @@
![llama-swap header image](header.jpeg)
![llama-swap header image](header2.png)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap
@@ -23,7 +26,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support
@@ -33,7 +36,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## config.yaml
@@ -43,14 +46,14 @@ llama-swap's configuration is purposefully simple.
models:
"qwen2.5":
proxy: "http://127.0.0.1:9999"
cmd: >
cmd: |
/app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999
"smollm2":
proxy: "http://127.0.0.1:9999"
cmd: >
cmd: |
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
@@ -64,15 +67,31 @@ models:
# Default (and minimum) is 15 seconds
healthCheckTimeout: 60
# Write HTTP logs (useful for troubleshooting), defaults to false
logRequests: true
# Valid log levels: debug, info (default), warn, error
logLevel: info
# Automatic Port Values
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
# when you use ${PORT} you can omit a custom model.proxy value, as it will
# default to http://localhost:${PORT}
# override the default port (5800) for automatic port values
startPort: 10001
# define valid model values and the upstream server start
models:
"llama":
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
# multiline for readability
cmd: |
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# where to reach the server started by cmd, make sure the ports match
# can be omitted if you use an automatic ${PORT} in cmd
proxy: http://127.0.0.1:8999
# aliases names to use this model for
@@ -91,49 +110,83 @@ models:
# default: 0 = never unload model
ttl: 60
"qwen":
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
# `useModelName` overrides the model name in the request
# and sends a specific name to the upstream server
useModelName: "qwen:qwq"
# unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal
"qwen-unlisted":
unlisted: true
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker Support (v26.1.4+ required!)
"docker-llama":
proxy: "http://127.0.0.1:9790"
cmd: >
proxy: "http://127.0.0.1:${PORT}"
cmd: |
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# `useModelName` will send a specific model name to the upstream server
# overriding whatever was set in the request
"qwq":
proxy: http://127.0.0.1:11434
cmd: my-server
useModelName: "qwen:qwq"
# use a custom command to stop the model when swapping. By default
# this is SIGTERM on POSIX systems, and taskkill on Windows systems
# the ${PID} variable can be used in cmdStop, it will be automatically replaced
# with the PID of the running model
cmdStop: docker stop dockertest
# profiles make it easy to managing multi model (and gpu) configurations.
# Groups provide advanced controls over model swapping behaviour. Using groups
# some models can be kept loaded indefinitely, while others are swapped out.
#
# Tips:
# - each model must be listening on a unique address and port
# - the model name is in this format: "profile_name:model", like "coding:qwen"
# - the profile will load and unload all models in the profile at the same time
profiles:
coding:
- "qwen"
- "llama"
#
# - models must be defined above in the Models section
# - a model can only be a member of one group
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
# - see issue #109 for details
#
# NOTE: the example below uses model names that are not defined above for demonstration purposes
groups:
# group1 is the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap controls the model swapping behaviour in within the group
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# exclusive controls how the group affects other groups
# - true: causes all other groups to unload their models when this group runs a model
# - false: does not affect other groups
exclusive: true
# members references the models defined above
members:
- "llama"
- "qwen-unlisted"
# models in this group are never unloaded
"group2":
swap: false
exclusive: false
members:
- "docker-llama"
# (not defined above, here for example)
- "modelA"
- "modelB"
"forever":
# setting persistent to true causes the group to never be affected by the swapping behaviour of
# other groups. It is a shortcut to keeping some models always loaded.
persistent: true
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
```
### Use Case Examples
@@ -142,18 +195,13 @@ profiles:
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
## Configuration
llama-s
</details>
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker is the quickest way to try out llama-swap:
```
```shell
# use CPU inference
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
@@ -189,7 +237,7 @@ Specific versions are also available and are tagged with the llama-swap, archite
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
```
```shell
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
@@ -204,7 +252,12 @@ Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
1. Run the binary with `llama-swap --config path/to/config.yaml`
1. Run the binary with `llama-swap --config path/to/config.yaml`.
Available flags:
- `--config`: Path to the configuration file (default: `config.yaml`).
- `--listen`: Address and port to listen on (default: `:8080`).
- `--version`: Show version information and exit.
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
### Building from source
@@ -219,13 +272,19 @@ Open the `http://<host>/logs` with your browser to get a web interface with stre
Of course, CLI access is also supported:
```
```shell
# sends up to the last 10KB of logs
curl http://host/logs'
# streams logs
# streams combined logs
curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
@@ -267,8 +326,4 @@ WantedBy=multi-user.target
## Star History
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
+23 -27
View File
@@ -1,17 +1,24 @@
# Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds
healthCheckTimeout: 15
healthCheckTimeout: 90
# Log HTTP requests helpful for troubleshoot, defaults to False
logRequests: true
# valid log levels: debug, info (default), warn, error
logLevel: debug
# creating a coding profile with models for code generation and general questions
groups:
coding:
swap: false
members:
- "qwen"
- "llama"
models:
"llama":
cmd: >
cmd: |
models/llama-server-osx
--port 9001
--port ${PORT}
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
proxy: http://127.0.0.1:9001
# list of model name aliases this llama.cpp instance can serve
aliases:
@@ -24,17 +31,15 @@ models:
ttl: 5
"qwen":
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9002
cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf
aliases:
- gpt-3.5-turbo
- gpt-3.5-turbo
# Embedding example with Nomic
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
"nomic":
proxy: http://127.0.0.1:9005
cmd: >
models/llama-server-osx --port 9005
cmd: |
models/llama-server-osx --port ${PORT}
-m models/nomic-embed-text-v1.5.Q8_0.gguf
--ctx-size 8192
--batch-size 8192
@@ -46,19 +51,17 @@ models:
# Reranking example with bge-reranker
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
"bge-reranker":
proxy: http://127.0.0.1:9006
cmd: >
models/llama-server-osx --port 9006
cmd: |
models/llama-server-osx --port ${PORT}
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
--ctx-size 8192
--reranking
# Docker Support (v26.1.4+ required!)
"dockertest":
proxy: "http://127.0.0.1:9790"
cmd: >
cmd: |
docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
@@ -67,8 +70,7 @@ models:
env:
- CUDA_VISIBLE_DEVICES=0,1
- env1=hello
cmd: build/simple-responder --port 8999
proxy: http://127.0.0.1:8999
cmd: build/simple-responder --port ${PORT}
unlisted: true
# use "none" to skip check. Caution this may cause some requests to fail
@@ -83,10 +85,4 @@ models:
"broken_timeout":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9000
unlisted: true
# creating a coding profile with models for code generation and general questions
profiles:
coding:
- "qwen"
- "llama"
unlisted: true
+153
View File
@@ -0,0 +1,153 @@
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
## Here's what you you need:
- aider - [installation docs](https://aider.chat/docs/install.html)
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
- 24GB VRAM video card
## Running aider
The goal is getting this command line to work:
```sh
aider --architect \
--no-show-model-warnings \
--model openai/QwQ \
--editor-model openai/qwen-coder-32B \
--model-settings-file aider.model.settings.yml \
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1" \
```
Set `--openai-api-base` to the IP and port where your llama-swap is running.
## Create an aider model settings file
```yaml
# aider.model.settings.yml
#
# !!! important: model names must match llama-swap configuration names !!!
#
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
```
## llama-swap configuration
```yaml
# config.yaml
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
models:
"qwen-coder-32B":
proxy: "http://127.0.0.1:8999"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--cache-type-k q8_0 --cache-type-v q8_0
-ngl 99
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
"QwQ":
proxy: "http://127.0.0.1:9503"
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
--min-p 0.01 --top-k 40 --top-p 0.95
-ngl 99
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
```
## Advanced, Dual GPU Configuration
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
In llama-swap's configuration file:
1. add a `profiles` section with `aider` as the profile name
2. using the `env` field to specify the GPU IDs for each model
```yaml
# config.yaml
# Add a profile for aider
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=0"
proxy: "http://127.0.0.1:8999"
cmd: /path/to/llama-server ...
"QwQ":
# manually set the GPU to run on
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
cmd: /path/to/llama-server ...
```
Append the profile tag, `aider:`, to the model names in the model settings file
```yaml
# aider.model.settings.yml
- name: "openai/aider:QwQ"
weak_model_name: "openai/aider:qwen-coder-32B-aider"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
- name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B-aider"
```
Run aider with:
```sh
$ aider --architect \
--no-show-model-warnings \
--model openai/aider:QwQ \
--editor-model openai/aider:qwen-coder-32B \
--config aider.conf.yml \
--model-settings-file aider.model.settings.yml
--openai-api-key "sk-na" \
--openai-api-base "http://10.0.1.24:8080/v1"
```
@@ -0,0 +1,28 @@
# this makes use of llama-swap's profile feature to
# keep the architect and editor models in VRAM on different GPUs
- name: "openai/aider:QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/aider:qwen-coder-32B"
editor_model_name: "openai/aider:qwen-coder-32B"
- name: "openai/aider:qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/aider:qwen-coder-32B"
@@ -0,0 +1,26 @@
- name: "openai/QwQ"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.95
top_k: 40
presence_penalty: 0.1
repetition_penalty: 1
num_ctx: 16384
use_temperature: 0.6
reasoning_tag: think
weak_model_name: "openai/qwen-coder-32B"
editor_model_name: "openai/qwen-coder-32B"
- name: "openai/qwen-coder-32B"
edit_format: diff
extra_params:
max_tokens: 16384
top_p: 0.8
top_k: 20
repetition_penalty: 1.05
use_temperature: 0.6
reasoning_tag: think
editor_edit_format: editor-diff
editor_model_name: "openai/qwen-coder-32B"
+49
View File
@@ -0,0 +1,49 @@
healthCheckTimeout: 300
logLevel: debug
profiles:
aider:
- qwen-coder-32B
- QwQ
models:
"qwen-coder-32B":
env:
- "CUDA_VISIBLE_DEVICES=0"
aliases:
- coder
proxy: "http://127.0.0.1:8999"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 8999 --flash-attn --slots
--ctx-size 16000
--ctx-size-draft 16000
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
-ngl 99 -ngld 99
--draft-max 16 --draft-min 4 --draft-p-min 0.4
--cache-type-k q8_0 --cache-type-v q8_0
"QwQ":
env:
- "CUDA_VISIBLE_DEVICES=1"
proxy: "http://127.0.0.1:9503"
# set appropriate paths for your environment
cmd: >
/path/to/llama-server
--host 127.0.0.1 --port 9503
--flash-attn --metrics
--slots
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
--cache-type-k q8_0 --cache-type-v q8_0
--ctx-size 32000
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
--temp 0.6
--repeat-penalty 1.1
--dry-multiplier 0.5
--min-p 0.01
--top-k 40
--top-p 0.95
-ngl 99 -ngld 99
+3 -2
View File
@@ -3,8 +3,8 @@ module github.com/mostlygeek/llama-swap
go 1.23.0
require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
@@ -12,6 +12,7 @@ require (
)
require (
github.com/billziss-gh/golib v0.2.0 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
@@ -37,7 +38,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
+12 -18
View File
@@ -1,3 +1,5 @@
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -9,12 +11,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
@@ -23,6 +29,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
@@ -74,32 +82,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 351 KiB

+141 -11
View File
@@ -1,25 +1,34 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy"
)
var version string = "0"
var commit string = "abcd1234"
var date = "unknown"
var (
version string = "0"
commit string = "abcd1234"
date string = "unknown"
)
func main() {
// Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port")
showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
flag.Parse() // Parse the command-line flags
@@ -34,6 +43,10 @@ func main() {
os.Exit(1)
}
if len(config.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}
if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode)
} else {
@@ -42,18 +55,135 @@ func main() {
proxyManager := proxy.New(config)
// Setup channels for server management
reloadChan := make(chan *proxy.ProxyManager)
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create server with initial handler
srv := &http.Server{
Addr: *listenStr,
Handler: proxyManager,
}
// Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() {
<-sigChan
fmt.Println("Shutting down llama-swap")
proxyManager.Shutdown()
os.Exit(0)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Printf("Fatal server error: %v\n", err)
close(exitChan)
}
}()
fmt.Println("llama-swap listening on " + *listenStr)
if err := proxyManager.Run(*listenStr); err != nil {
fmt.Printf("Server error: %v\n", err)
os.Exit(1)
// Handle config reloads and signals
go func() {
currentManager := proxyManager
for {
select {
case newManager := <-reloadChan:
log.Println("Config change detected, waiting for in-flight requests to complete...")
// Stop old manager processes gracefully (this waits for in-flight requests)
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
// Now do a full shutdown to clear the process map
currentManager.Shutdown()
currentManager = newManager
srv.Handler = newManager
log.Println("Server handler updated with new config")
case sig := <-sigChan:
fmt.Printf("Received signal %v, shutting down...\n", sig)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentManager.Shutdown()
if err := srv.Shutdown(ctx); err != nil {
fmt.Printf("Server shutdown error: %v\n", err)
}
close(exitChan)
return
}
}
}()
// Start file watcher if requested
if *watchConfig {
absConfigPath, err := filepath.Abs(*configPath)
if err != nil {
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
} else {
go watchConfigFileWithReload(absConfigPath, reloadChan)
}
}
// Wait for exit signal
<-exitChan
}
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
return
}
defer watcher.Close()
err = watcher.Add(configPath)
if err != nil {
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
return
}
log.Printf("Watching config file for changes: %s", configPath)
var debounceTimer *time.Timer
debounceDuration := 2 * time.Second
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
// We only care about writes to the specific config file
if event.Name == configPath && event.Has(fsnotify.Write) {
// Reset or start the debounce timer
if debounceTimer != nil {
debounceTimer.Stop()
}
debounceTimer = time.AfterFunc(debounceDuration, func() {
log.Printf("Config file modified: %s, reloading...", event.Name)
// Try up to 3 times with exponential backoff
var newConfig proxy.Config
var err error
for retries := 0; retries < 3; retries++ {
// Load new configuration
newConfig, err = proxy.LoadConfig(configPath)
if err == nil {
break
}
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
if retries < 2 {
time.Sleep(time.Duration(1<<retries) * time.Second)
}
}
if err != nil {
log.Printf("Failed to load new config after retries: %v", err)
return
}
// Create new ProxyManager with new config
newPM := proxy.New(newConfig)
reloadChan <- newPM
log.Println("Config reloaded successfully")
})
}
case err, ok := <-watcher.Errors:
if !ok {
log.Println("File watcher error channel closed.")
return
}
log.Printf("File watcher error: %v", err)
}
}
}
+48 -7
View File
@@ -26,6 +26,8 @@ func main() {
silent := flag.Bool("silent", false, "disable all logging")
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
flag.Parse() // Parse the command-line flags
// Create a new Gin router
@@ -33,14 +35,17 @@ func main() {
// Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.Header("Content-Type", "application/json")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
c.String(200, *responseMessage)
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
"h_content_length": c.Request.Header.Get("Content-Length"),
})
})
// for issue #62 to check model name strips profile slug
@@ -63,8 +68,11 @@ func main() {
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
c.Header("Content-Type", "application/json")
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
})
})
// issue #41
@@ -104,6 +112,10 @@ func main() {
c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model,
// expose some header values for testing
"h_content_type": c.GetHeader("Content-Type"),
"h_content_length": c.GetHeader("Content-Length"),
})
})
@@ -180,6 +192,10 @@ func main() {
log.SetOutput(io.Discard)
}
if !*silent {
fmt.Printf("My PID: %d\n", os.Getpid())
}
go func() {
log.Printf("simple-responder listening on %s\n", address)
// service connections
@@ -190,11 +206,36 @@ func main() {
// Wait for interrupt signal to gracefully shutdown the server with
// a timeout of 5 seconds.
quit := make(chan os.Signal, 1)
sigChan := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
countSigInt := 0
runloop:
for {
signal := <-sigChan
switch signal {
case syscall.SIGINT:
countSigInt++
if countSigInt > 1 {
break runloop
} else {
log.Println("Recieved SIGINT, send another SIGINT to shutdown")
}
case syscall.SIGTERM:
if *ignoreSigTerm {
log.Println("Ignoring SIGTERM")
} else {
log.Println("Recieved SIGTERM, shutting down")
break runloop
}
default:
break runloop
}
}
log.Println("simple-responder shutting down")
}
+179 -13
View File
@@ -2,15 +2,22 @@ package proxy
import (
"fmt"
"io"
"os"
"runtime"
"sort"
"strconv"
"strings"
"github.com/google/shlex"
"github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3"
)
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
@@ -18,20 +25,53 @@ type ModelConfig struct {
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
Models map[string]ModelConfig `yaml:"models"`
LogLevel string `yaml:"logLevel"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// map aliases to actual model IDs
aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
}
func (c *Config) RealModelName(search string) (string, bool) {
@@ -52,42 +92,168 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
}
}
func LoadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
func LoadConfig(path string) (Config, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
return Config{}, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
}
var config Config
err = yaml.Unmarshal(data, &config)
if err != nil {
return nil, err
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
config.HealthCheckTimeout = 15
}
// set default port ranges
if config.StartPort == 0 {
// default to 5800
config.StartPort = 5800
} else if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName
}
}
return &config, nil
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
// iterate over the models and replace any ${PORT} with the next available port
if strings.Contains(modelConfig.Cmd, "${PORT}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
if modelConfig.Proxy == "" {
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
} else {
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
}
nextPort++
config.Models[modelId] = modelConfig
} else if modelConfig.Proxy == "" {
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
}
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
}
func SanitizeCommand(cmdStr string) ([]string, error) {
// Remove trailing backslashes
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments
args, err := shlex.Split(cmdStr)
if err != nil {
return nil, err
var args []string
if runtime.GOOS == "windows" {
args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
}
// Ensure the command is not empty
+42
View File
@@ -0,0 +1,42 @@
//go:build !windows
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_SanitizeCommand(t *testing.T) {
// Test a command with spaces and newlines
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
# comment 1
--arg3 123 \
# comment 2
--arg4 '"string in string"'
# this will get stripped out as well as the white space above
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}
+182 -25
View File
@@ -3,6 +3,7 @@ package proxy
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -35,11 +36,32 @@ models:
aliases:
- "m2"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "mthree"
checkEndpoint: "/"
model4:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8082"
checkEndpoint: "/"
healthCheckTimeout: 15
profiles:
test:
- model1
- model2
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
forever:
exclusive: false
persistent: true
members:
- "model4"
`
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
@@ -52,7 +74,8 @@ profiles:
t.Fatalf("Failed to load config: %v", err)
}
expected := &Config{
expected := Config{
StartPort: 5800,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
@@ -68,6 +91,18 @@ profiles:
Env: nil,
CheckEndpoint: "/",
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: nil,
CheckEndpoint: "/",
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
},
},
HealthCheckTimeout: 15,
Profiles: map[string][]string{
@@ -77,6 +112,25 @@ profiles:
"m1": "model1",
"model-one": "model1",
"m2": "model2",
"mthree": "model3",
},
Groups: map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
},
}
@@ -87,6 +141,63 @@ profiles:
assert.Equal(t, "model1", realname)
}
func TestConfig_GroupMemberIsUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
healthCheckTimeout: 15
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
group2:
swap: true
exclusive: false
members: ["model2"]
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// a Contains as order of the map is not guaranteed
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
}
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
aliases:
- m1
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
aliases:
- m1
- m2
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// this is a contains because it could be `model1` or `model2` depending on the order
// go decided on the order of the map
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
@@ -147,30 +258,76 @@ func TestConfig_FindConfig(t *testing.T) {
assert.Equal(t, ModelConfig{}, modelConfig)
}
func TestConfig_SanitizeCommand(t *testing.T) {
func TestConfig_AutomaticPortAssignments(t *testing.T) {
// Test a command with spaces and newlines
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
--arg3 123 \
--arg4 '"string in string"'
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
t.Run("Default Port Ranges", func(t *testing.T) {
content := ``
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
assert.Equal(t, 5800, config.StartPort)
})
t.Run("User specific port ranges", func(t *testing.T) {
content := `startPort: 1000`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 1000, config.StartPort)
})
t.Run("Invalid start port", func(t *testing.T) {
content := `startPort: abcd`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("start port must be greater than 1", func(t *testing.T) {
content := `startPort: -99`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("Automatic port assignments", func(t *testing.T) {
content := `
startPort: 5800
models:
model1:
cmd: svr --port ${PORT}
model2:
cmd: svr --port ${PORT}
proxy: "http://172.11.22.33:${PORT}"
model3:
cmd: svr --port 1999
proxy: "http://1.2.3.4:1999"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
})
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
content := `
models:
model1:
cmd: svr --port 111
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
})
}
+41
View File
@@ -0,0 +1,41 @@
//go:build windows
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_SanitizeCommand(t *testing.T) {
// does not support single quoted strings like in config_posix_test.go
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
-s
--arg3 123 \
# comment 2
--arg4 '"string in string"'
# this will get stripped out as well as the white space above
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"-s",
"--arg3", "123",
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}
+24 -3
View File
@@ -14,6 +14,7 @@ import (
var (
nextTestPort int = 12000
portMutex sync.Mutex
testLogger = NewLogMonitorWriter(os.Stdout)
)
// Check if the binary exists
@@ -26,6 +27,17 @@ func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode)
switch os.Getenv("LOG_LEVEL") {
case "debug":
testLogger.SetLogLevel(LevelDebug)
case "warn":
testLogger.SetLogLevel(LevelWarn)
case "info":
testLogger.SetLogLevel(LevelInfo)
default:
testLogger.SetLogLevel(LevelWarn)
}
m.Run()
}
@@ -33,17 +45,26 @@ func TestMain(m *testing.M) {
func getSimpleResponderPath() string {
goos := runtime.GOOS
goarch := runtime.GOARCH
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
if goos == "windows" {
return filepath.Join("..", "build", "simple-responder.exe")
} else {
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
func getTestPort() int {
portMutex.Lock()
defer portMutex.Unlock()
port := nextTestPort
nextTestPort++
return getTestSimpleResponderConfigPort(expectedMessage, port)
return port
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
}
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
+189 -75
View File
@@ -12,32 +12,65 @@
flex-direction: column;
font-family: "Courier New", Courier, monospace;
}
#log-controls {
margin: 0.5em;
.log-container {
display: flex;
align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1;
}
#log-controls input:focus {
outline: none; /* Ensures no outline is shown when the input is focused */
}
#log-stream {
flex: 1;
gap: 0.5em;
margin: 0.5em;
min-height: 0;
}
.log-column {
display: flex;
flex-direction: column;
flex: 1;
min-width: 0;
transition: flex 0.3s ease;
}
.log-column.minimized {
flex: 0.1;
max-width: 50px;
border: 1px solid #777;
color: green;
}
.log-controls {
display: grid;
grid-template-columns: 1fr auto;
gap: 0.5em;
margin-bottom: 0.5em;
}
.log-controls input {
width: 100%;
padding: 4px;
}
.log-controls input:focus {
outline: none;
}
.log-stream {
flex: 1;
padding: 1em;
background: #f4f4f4;
overflow-y: auto;
white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; /* Ensures long words wrap */
white-space: pre-wrap;
word-wrap: break-word;
min-height: 0;
}
.regex-error {
background-color: #ff0000 !important;
}
/* Make headers clickable and show pointer cursor */
h2 {
cursor: pointer;
user-select: none;
margin: 0 0 0.5em 0;
padding: 0.5em;
}
h2:hover {
background-color: rgba(0, 0, 0, 0.05);
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
body {
@@ -45,101 +78,182 @@
color: #fff;
}
#log-stream {
.log-stream {
background: #444;
color: #fff;
}
#log-controls input {
.log-controls input {
background: #555;
color: #fff;
border: 1px solid #777;
}
#log-controls button {
.log-controls button {
background: #555;
color: #fff;
border: 1px solid #777;
}
h2:hover {
background-color: rgba(255, 255, 255, 0.1);
}
}
/* Hide content when minimized */
.log-column.minimized .log-controls,
.log-column.minimized .log-stream {
display: none;
}
.log-column.minimized h2 {
writing-mode: vertical-rl;
text-orientation: mixed;
transform: rotate(180deg);
white-space: nowrap;
margin: auto;
}
</style>
</head>
<body>
<pre id="log-stream">Waiting for logs...</pre>
<div id="log-controls">
<input type="text" id="filter-input" placeholder="regex filter">
<button id="clear-button">clear</button>
<div class="log-container">
<div class="log-column">
<h2>Proxy Logs</h2>
<div class="log-controls">
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
<button id="proxy-clear-button">clear</button>
</div>
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
</div>
<div class="log-column minimized">
<h2>Upstream Logs</h2>
<div class="log-controls">
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
<button id="upstream-clear-button">clear</button>
</div>
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
</div>
</div>
<script>
const logStream = document.getElementById('log-stream');
const filterInput = document.getElementById('filter-input');
var logData = "";
let regexFilter = null;
class LogStream {
constructor(streamElement, filterInput, clearButton, endpoint) {
this.streamElement = streamElement;
this.filterInput = filterInput;
this.clearButton = clearButton;
this.endpoint = endpoint;
this.logData = "";
this.regexFilter = null;
this.eventSource = null;
function setupEventSource() {
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
eventSource.onmessage = function(event) {
logData += event.data;
render()
};
eventSource.onerror = function(err) {
logData = "EventSource failed: " + err.message;
};
} else {
logData = "SSE Not supported by this browser."
this.initialize();
}
}
// poor-ai's react ¯\_(ツ)_/¯
function render() {
if (regexFilter) {
const lines = logData.split('\n');
const filteredLines = lines.filter(line => {
return regexFilter === null || regexFilter.test(line);
initialize() {
this.filterInput.addEventListener('input', () => this.updateFilter());
this.clearButton.addEventListener('click', () => {
this.filterInput.value = "";
this.regexFilter = null;
this.render();
});
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
this.setupEventSource();
}
logStream.scrollTop = logStream.scrollHeight;
}
setupEventSource() {
if (typeof(EventSource) === "undefined") {
this.logData = "SSE Not supported by this browser.";
this.render();
return;
}
const connect = () => {
this.eventSource = new EventSource(this.endpoint);
this.eventSource.onmessage = (event) => {
this.logData += event.data;
this.logData = this.logData.slice(-1024 * 100);
this.render();
};
this.eventSource.onerror = (err) => {
// Close the current connection
this.eventSource.close();
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
this.render();
// Attempt to reconnect after 5 seconds
setTimeout(() => {
this.logData += "Attempting to reconnect...\n";
this.render();
connect();
}, 5000);
};
};
// Initial connection
connect();
}
render() {
let content = this.logData;
if (this.regexFilter) {
const lines = content.split('\n');
const filteredLines = lines.filter(line => this.regexFilter.test(line));
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
}
this.streamElement.textContent = content;
this.streamElement.scrollTop = this.streamElement.scrollHeight;
}
updateFilter() {
const pattern = this.filterInput.value.trim();
this.filterInput.classList.remove('regex-error');
if (!pattern) {
this.regexFilter = null;
this.render();
return;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try {
regexFilter = new RegExp(pattern);
this.regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
regexFilter = null;
filterInput.classList.add('regex-error');
return
this.regexFilter = null;
this.filterInput.classList.add('regex-error');
return;
}
} else {
regexFilter = null;
}
render();
this.render();
}
}
filterInput.addEventListener('input', updateFilter);
document.getElementById('clear-button').addEventListener('click', () => {
filterInput.value = "";
regexFilter = null;
render();
// Initialize both log streams
document.addEventListener('DOMContentLoaded', () => {
new LogStream(
document.getElementById('proxy-log-stream'),
document.getElementById('proxy-filter-input'),
document.getElementById('proxy-clear-button'),
"/logs/streamSSE/proxy"
);
new LogStream(
document.getElementById('upstream-log-stream'),
document.getElementById('upstream-filter-input'),
document.getElementById('upstream-clear-button'),
"/logs/streamSSE/upstream"
);
// Initialize clickable headers
document.querySelectorAll('h2').forEach(header => {
header.addEventListener('click', () => {
const column = header.closest('.log-column');
column.classList.toggle('minimized');
});
});
});
setupEventSource();
updateFilter();
</script>
</body>
</html>
+90
View File
@@ -2,11 +2,21 @@ package proxy
import (
"container/ring"
"fmt"
"io"
"os"
"sync"
)
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct {
clients map[chan []byte]bool
mu sync.RWMutex
@@ -15,6 +25,10 @@ type LogMonitor struct {
// typically this can be os.Stdout
stdout io.Writer
// logging levels
level LogLevel
prefix string
}
func NewLogMonitor() *LogMonitor {
@@ -26,6 +40,8 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
}
}
@@ -94,3 +110,77 @@ func (w *LogMonitor) broadcast(msg []byte) {
}
}
}
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+298 -142
View File
@@ -8,6 +8,8 @@ import (
"net/http"
"net/url"
"os/exec"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
@@ -29,12 +31,26 @@ const (
StateShutdown ProcessState = ProcessState("shutdown")
)
type StopStrategy int
const (
StopImmediately StopStrategy = iota
StopWaitForInflightRequest
)
type Process struct {
ID string
config ModelConfig
cmd *exec.Cmd
logMonitor *LogMonitor
healthCheckTimeout int
ID string
config ModelConfig
cmd *exec.Cmd
// for p.cmd.Wait() select { ... }
cmdWaitChan chan error
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time
@@ -49,56 +65,95 @@ type Process struct {
// for managing shutdown state
shutdownCtx context.Context
shutdownCancel context.CancelFunc
// for managing concurrency limits
concurrencyLimitSemaphore chan struct{}
// stop timeout waiting for graceful shutdown
gracefulStopTimeout time.Duration
// track that this happened
upstreamWasStoppedWithKill bool
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background())
concurrentLimit := 10
if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit
}
return &Process{
ID: ID,
config: config,
cmd: nil,
logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
ID: ID,
config: config,
cmd: nil,
cmdWaitChan: make(chan error, 1),
processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
// concurrency limit
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
// stop timeout
gracefulStopTimeout: 5 * time.Second,
upstreamWasStoppedWithKill: false,
}
}
func (p *Process) setState(newState ProcessState) error {
// enforce valid state transitions
invalidTransition := false
if p.state == StateStopped {
// stopped -> starting
if newState != StateStarting {
invalidTransition = true
}
} else if p.state == StateStarting {
// starting -> ready | failed | stopping
if newState != StateReady && newState != StateFailed && newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateReady {
// ready -> stopping
if newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateStopping {
// stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
// LogMonitor returns the log monitor associated with the process.
func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
ErrInvalidStateTransition = errors.New("invalid state transition")
)
// swapState performs a compare and swap of the state atomically. It returns the current state
// and an error if the swap failed.
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != expectedState {
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
return p.state, ErrExpectedStateMismatch
}
if invalidTransition {
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
return p.state, ErrInvalidStateTransition
}
p.state = newState
return nil
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
return p.state, nil
}
// Helper function to encapsulate transition rules
func isValidTransition(from, to ProcessState) bool {
switch from {
case StateStopped:
return to == StateStarting
case StateStarting:
return to == StateReady || to == StateFailed || to == StateStopping
case StateReady:
return to == StateStopping
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateFailed:
return to == StateStopping
case StateShutdown:
return false // No transitions allowed from these states
}
return false
}
func (p *Process) CurrentState() ProcessState {
@@ -116,68 +171,67 @@ func (p *Process) start() error {
return fmt.Errorf("can not start(), upstream proxy missing")
}
// multiple start() calls will wait for the one that is actually starting to
// complete before proceeding.
// ===========
curState := p.CurrentState()
if curState == StateReady {
return nil
}
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state != StateReady {
return fmt.Errorf("start() failed current state: %v", state)
}
return nil
}
// ===========
// There is the possibility of a hard to replicate race condition where
// curState *WAS* StateStopped but by the time we get to the p.stateMutex.Lock()
// below, it's value has changed!
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// with the exclusive lock, check if p.state is StateStopped, which is the only valid state
// to transition from to StateReady
if p.state != StateStopped {
if p.state == StateReady {
return nil
} else {
return fmt.Errorf("start() can not proceed expected StateReady but process is in %v", p.state)
}
}
if err := p.setState(StateStarting); err != nil {
return err
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand()
if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err)
}
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
// already starting, just wait for it to complete and expect
// it to be be in the Ready start after. If not, return an error
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state == StateReady {
return nil
} else {
return fmt.Errorf("process was already starting but wound up in state %v", state)
}
} else {
return fmt.Errorf("processes was in state %v when start() was called", curState)
}
} else {
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.logMonitor
p.cmd.Stdout = p.processLogger
p.cmd.Stderr = p.processLogger
p.cmd.Env = p.config.Env
err = p.cmd.Start()
// Set process state to failed
if err != nil {
p.setState(StateFailed)
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
return fmt.Errorf(
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
err, curState, swapErr,
)
}
return fmt.Errorf("start() failed: %v", err)
}
// Capture the exit error for later signalling
go func() {
exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
// there is a race condition when SIGKILL is used, p.cmd.Wait() returns, and then
// the code below fires, putting an error into cmdWaitChan. This code is to prevent this
if p.upstreamWasStoppedWithKill {
p.proxyLogger.Debugf("<%s> process was killed, NOT sending exitErr: %v", p.ID, exitErr)
p.upstreamWasStoppedWithKill = false
return
}
p.cmdWaitChan <- exitErr
}()
// One of three things can happen at this stage:
// 1. The command exits unexpectedly
// 2. The health check fails
@@ -209,31 +263,51 @@ func (p *Process) start() error {
)
defer cancelHealthCheck()
// Health check loop
loop:
// Ready Check loop
for {
select {
case <-checkDeadline.Done():
p.setState(StateFailed)
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
} else {
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown")
case exitErr := <-p.cmdWaitChan:
if exitErr != nil {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
} else {
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
}
} else {
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
} else {
return fmt.Errorf("upstream command exited prematurely but successfully")
}
}
default:
if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
cancelHealthCheck()
break loop
} else {
if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime)
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
} else {
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
}
}
}
<-time.After(5 * time.Second)
<-time.After(p.healthCheckLoopInterval)
}
}
@@ -244,7 +318,7 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) {
if p.state != StateReady {
if p.CurrentState() != StateReady {
return
}
@@ -252,7 +326,7 @@ func (p *Process) start() error {
p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
p.Stop()
return
}
@@ -260,92 +334,152 @@ func (p *Process) start() error {
}()
}
return p.setState(StateReady)
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
} else {
return nil
}
}
// Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// calling Stop() when state is invalid is a no-op
if err := p.setState(StateStopping); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
// stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second)
// wait for any inflight requests before proceeding
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
p.inFlightRequests.Wait()
p.StopImmediately()
}
if err := p.setState(StateStopped); err != nil {
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
currentState := p.CurrentState()
if currentState == StateFailed {
if curState, err := p.swapState(StateFailed, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Failed -> StateStopping err: %v, current state: %v", p.ID, err, curState)
return
}
} else {
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
return
}
}
// stop the process with a graceful exit timeout
p.stopCommand(p.gracefulStopTimeout)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
}
}
// Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down
// is in the state of starting, it will cancel it and shut it down. Once a process is in
// the StateShutdown state, it can not be started again.
func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second)
if err := p.setState(StateShutdown); err != nil {
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
p.setState(StateShutdown)
p.shutdownCancel()
p.stopCommand(p.gracefulStopTimeout)
// just force it to this state since there is no recovery from shutdown
p.state = StateShutdown
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil {
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
return
}
p.cmd.Process.Signal(syscall.SIGTERM)
// if err := p.terminateProcess(); err != nil {
// p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
// }
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" {
p.config.CmdStop = "taskkill /f /t /pid ${PID}"
}
if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
if err != nil {
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
return
}
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Stdout = p.processLogger
stopCmd.Stderr = p.processLogger
stopCmd.Env = p.config.Env
if err := stopCmd.Run(); err != nil {
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
return
}
} else {
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
return
}
}
select {
case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.cmd.Process.Kill()
case err := <-sigtermNormal:
p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID)
p.upstreamWasStoppedWithKill = true
if err := p.cmd.Process.Kill(); err != nil {
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
}
case err := <-p.cmdWaitChan:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that the cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now.
if err != nil {
if errno, ok := err.(syscall.Errno); ok {
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
} else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
}
} else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err)
}
}
}
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{
Timeout: 500 * time.Millisecond,
}
@@ -370,6 +504,8 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
requestBeginTime := time.Now()
var startDuration time.Duration
// prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState()
@@ -378,6 +514,14 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return
}
select {
case p.concurrencyLimitSemaphore <- struct{}{}:
defer func() { <-p.concurrencyLimitSemaphore }()
default:
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}
p.inFlightRequests.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
@@ -386,11 +530,13 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// start the process on demand
if p.CurrentState() != StateReady {
beginStartTime := time.Now()
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway)
return
}
startDuration = time.Since(beginStartTime)
}
proxyTo := p.config.Proxy
@@ -401,6 +547,12 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return
}
req.Header = r.Header.Clone()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
@@ -434,4 +586,8 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return
}
}
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
}
+197 -35
View File
@@ -2,10 +2,10 @@ package proxy
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"runtime"
"sync"
"testing"
"time"
@@ -13,13 +13,26 @@ import (
"github.com/stretchr/testify/assert"
)
var (
debugLogger = NewLogMonitorWriter(os.Stdout)
)
func init() {
// flip to help with debugging tests
if false {
debugLogger.SetLogLevel(LevelDebug)
} else {
debugLogger.SetLogLevel(LevelError)
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
// Create a process
process := NewProcess("test-process", 5, config, logMonitor)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
@@ -52,11 +65,10 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, logMonitor)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
defer process.Stop()
var wg sync.WaitGroup
@@ -84,7 +96,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health",
}
process := NewProcess("broken", 1, config, NewLogMonitor())
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
@@ -109,7 +121,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
defer process.Stop()
// this should take 4 seconds
@@ -151,7 +163,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
defer process.Stop()
for i := 0; i < 100; i++ {
@@ -178,7 +190,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
process := NewProcess("t", 10, config, debugLogger, debugLogger)
defer process.Stop()
results := map[string]string{
@@ -225,39 +237,40 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
}
}
func TestSetState(t *testing.T) {
func TestProcess_SwapState(t *testing.T) {
tests := []struct {
name string
currentState ProcessState
expectedState ProcessState
newState ProcessState
expectedError error
expectedResult ProcessState
}{
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p := &Process{
state: test.currentState,
}
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
p.state = test.currentState
err := p.setState(test.newState)
resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil {
@@ -268,8 +281,8 @@ func TestSetState(t *testing.T) {
}
}
if p.state != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
if resultState != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
}
})
}
@@ -280,7 +293,6 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test")
}
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong
@@ -288,13 +300,16 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
// start a goroutine to simulate a shutdown
var wg sync.WaitGroup
go func() {
defer wg.Done()
<-time.After(time.Second * 2)
<-time.After(time.Millisecond * 500)
process.Shutdown()
}()
wg.Add(1)
@@ -306,3 +321,150 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState())
}
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping Exit Interrupts Health Check test")
}
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = time.Second // make it faster
err := process.start()
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed)
}
func TestProcess_ConcurrencyLimit(t *testing.T) {
if testing.Short() {
t.Skip("skipping long concurrency limit test")
}
expectedMessage := "concurrency_limit_test"
config := getTestSimpleResponderConfig(expectedMessage)
// only allow 1 concurrent request at a time
config.ConcurrencyLimit = 1
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
defer process.Stop()
// launch a goroutine first to take up the semaphore
go func() {
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req1)
assert.Equal(t, http.StatusOK, w.Code)
}()
// let the goroutine start
<-time.After(time.Millisecond * 25)
denied := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, denied)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
}
func TestProcess_StopImmediately(t *testing.T) {
expectedMessage := "test_stop_immediate"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
// the upstream command
func TestProcess_ForceStopWithKill(t *testing.T) {
expectedMessage := "test_sigkill"
binaryPath := getSimpleResponderPath()
port := getTestPort()
config := ModelConfig{
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
// to force the process to exit
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
}
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()
// reduce to make testing go faster
process.gracefulStopTimeout = time.Second
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
waitChan := make(chan struct{})
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
// StatusOK because that was already sent before the kill
assert.Equal(t, http.StatusOK, w.Code)
// unexpected EOF because the kill happened, the "1" is sent before the kill
// then the unexpected EOF is sent after the kill
if runtime.GOOS == "windows" {
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
} else {
assert.Contains(t, w.Body.String(), "unexpected EOF")
}
close(waitChan)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
// the request should have been interrupted by SIGKILL
<-waitChan
}
func TestProcess_StopCmd(t *testing.T) {
config := getTestSimpleResponderConfig("test_stop_cmd")
if runtime.GOOS == "windows" {
config.CmdStop = "taskkill /f /t /pid ${PID}"
} else {
config.CmdStop = "kill -TERM ${PID}"
}
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}
+114
View File
@@ -0,0 +1,114 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sync"
)
type ProcessGroup struct {
sync.Mutex
config Config
id string
swap bool
exclusive bool
persistent bool
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// map of current processes
processes map[string]*Process
lastUsedProcess string
}
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
}
pg := &ProcessGroup{
id: id,
config: config,
swap: groupConfig.Swap,
exclusive: groupConfig.Exclusive,
persistent: groupConfig.Persistent,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
processes: make(map[string]*Process),
}
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
return pg
}
// ProxyRequest proxies a request to the specified model
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
if !pg.HasMember(modelID) {
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
}
if pg.swap {
pg.Lock()
if pg.lastUsedProcess != modelID {
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
}
pg.lastUsedProcess = modelID
}
pg.Unlock()
}
pg.processes[modelID].ProxyRequest(writer, request)
return nil
}
func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
if len(pg.processes) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
}(process)
}
wg.Wait()
}
func (pg *ProcessGroup) Shutdown() {
var wg sync.WaitGroup
for _, process := range pg.processes {
wg.Add(1)
go func(process *Process) {
defer wg.Done()
process.Shutdown()
}(process)
}
wg.Wait()
}
+96
View File
@@ -0,0 +1,96 @@
package proxy
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model2"},
},
"G2": {
Swap: false,
Exclusive: true,
Members: []string{"model3", "model4"},
},
},
})
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}
func TestProcessGroup_HasMember(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model1"))
assert.True(t, pg.HasMember("model2"))
assert.False(t, pg.HasMember("model3"))
}
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
// make sure only one process is in the running state
count := 0
for _, process := range pg.processes {
if process.CurrentState() == StateReady {
count++
}
}
assert.Equal(t, 1, count)
})
}
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model3", "model4"}
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
})
}
// make sure all the processes are running
for _, process := range pg.processes {
assert.Equal(t, StateReady, process.CurrentState())
}
}
+234 -207
View File
@@ -7,6 +7,7 @@ import (
"io"
"mime/multipart"
"net/http"
"os"
"sort"
"strconv"
"strings"
@@ -25,61 +26,116 @@ const (
type ProxyManager struct {
sync.Mutex
config *Config
currentProcesses map[string]*Process
logMonitor *LogMonitor
ginEngine *gin.Engine
config Config
ginEngine *gin.Engine
// logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
processGroups map[string]*ProcessGroup
}
func New(config *Config) *ProxyManager {
pm := &ProxyManager{
config: config,
currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(),
ginEngine: gin.New(),
}
func New(config Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
// see: https://github.com/mostlygeek/llama-swap/issues/42
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
upstreamLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
upstreamLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{
config: config,
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
processGroups: make(map[string]*ProcessGroup),
}
// create the process groups
for groupID := range config.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
pm.setupGinEngine()
return pm
}
func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP,
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
// see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.AbortWithStatus(204)
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// allow whatever the client requested by default
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
@@ -104,6 +160,8 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
@@ -139,63 +197,77 @@ func New(config *Config) *ProxyManager {
// Disable console color for testing
gin.DisableConsoleColor()
return pm
}
func (pm *ProxyManager) Run(addr ...string) error {
return pm.ginEngine.Run(addr...)
}
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
// ServeHTTP implements http.Handler interface
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r)
}
func (pm *ProxyManager) StopProcesses() {
// StopProcesses acquires a lock and stops all running upstream processes.
// This is the public method safe for concurrent calls.
// Unlike Shutdown, this method only stops the processes but doesn't perform
// a complete shutdown, allowing for process replacement without full termination.
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock()
defer pm.Unlock()
pm.stopProcesses()
}
// for internal usage
func (pm *ProxyManager) stopProcesses() {
if len(pm.currentProcesses) == 0 {
return
}
// stop Processes in parallel
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
for _, processGroup := range pm.processGroups {
wg.Add(1)
go func(process *Process) {
go func(processGroup *ProcessGroup) {
defer wg.Done()
process.Stop()
}(process)
processGroup.StopProcesses(strategy)
}(processGroup)
}
wg.Wait()
pm.currentProcesses = make(map[string]*Process)
wg.Wait()
}
// Shutdown is called to shutdown all upstream processes
// when llama-swap is shutting down.
// Shutdown stops all processes managed by this ProxyManager
func (pm *ProxyManager) Shutdown() {
pm.Lock()
defer pm.Unlock()
// shutdown process in parallel
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
var wg sync.WaitGroup
for _, process := range pm.currentProcesses {
// Send shutdown signal to all process in groups
for _, processGroup := range pm.processGroups {
wg.Add(1)
go func(process *Process) {
go func(processGroup *ProcessGroup) {
defer wg.Done()
process.Shutdown()
}(process)
processGroup.Shutdown()
}(processGroup)
}
wg.Wait()
}
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
}
if processGroup.exclusive {
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses(StopWaitForInflightRequest)
}
}
}
return processGroup, realModelName, nil
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{}
for id, modelConfig := range pm.config.Models {
@@ -225,78 +297,6 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
}
}
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock()
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", profileName)
}
}
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// check if model is part of the profile
if profileName != "" {
found := false
for _, item := range pm.config.Profiles[profileName] {
if item == realModelName {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
}
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
// stop all running models
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
requestedModel := c.Param("model_id")
@@ -305,13 +305,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return
}
if process, err := pm.swapModel(requestedModel); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
} else {
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
process.ProxyRequest(c.Writer, c.Request)
processGroup, _, err := pm.swapProcessGroup(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
}
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
@@ -332,7 +334,33 @@ func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
// Iterate over sorted keys
for _, modelID := range modelIDs {
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
// Get process state
processGroup := pm.findGroupByModelName(modelID)
var state string
if processGroup != nil {
process := processGroup.processes[modelID]
if process != nil {
var stateStr string
switch process.CurrentState() {
case StateReady:
stateStr = "Ready"
case StateStarting:
stateStr = "Starting"
case StateStopping:
stateStr = "Stopping"
case StateFailed:
stateStr = "Failed"
case StateShutdown:
stateStr = "Shutdown"
case StateStopped:
stateStr = "Stopped"
default:
stateStr = "Unknown"
}
state = stateStr
}
}
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a> - %s</li>", modelID, modelID, state))
}
html.WriteString("</ul></body></html>")
c.Header("Content-Type", "text/html")
@@ -349,50 +377,40 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
process, err := pm.swapModel(requestedModel)
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
} else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
@@ -406,15 +424,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return
}
// Swap to the requested model
process, err := pm.swapModel(requestedModel)
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// Get profile name and model name from the requested model
profileName, modelName := splitRequestedModel(requestedModel)
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
@@ -422,10 +441,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
if process.config.UseModelName != "" {
fieldValue = process.config.UseModelName
} else if profileName != "" {
fieldValue = modelName
// # issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
}
}
field, err := multipartWriter.CreateFormField(key)
@@ -486,8 +508,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
process.ProxyRequest(c.Writer, modifiedReq)
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
}
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
@@ -501,7 +531,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
}
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses()
pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK")
}
@@ -509,14 +539,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, process := range pm.currentProcesses {
// Append the process ID and State (multiple entries if profiles are being used).
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
})
for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
})
}
}
}
// Put the results under the `running` key.
@@ -527,15 +558,11 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.JSON(http.StatusOK, response) // Always return 200 OK
}
func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName
}
func splitRequestedModel(requestedModel string) (string, string) {
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
for _, group := range pm.processGroups {
if group.HasMember(modelName) {
return group
}
}
return profileName, modelName
return nil
}
+37 -8
View File
@@ -9,7 +9,6 @@ import (
)
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html
@@ -28,7 +27,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
}
} else {
c.Header("Content-Type", "text/plain")
history := pm.logMonitor.GetHistory()
history := pm.muxLogger.GetHistory()
_, err := c.Writer.Write(history)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
@@ -42,8 +41,14 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher)
@@ -56,7 +61,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped
if !skipHistory {
history := pm.logMonitor.GetHistory()
history := logger.GetHistory()
if len(history) != 0 {
c.Writer.Write(history)
flusher.Flush()
@@ -85,15 +90,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done()
// Send history first if not skipped
_, skipHistory := c.GetQuery("no-history")
if !skipHistory {
history := pm.logMonitor.GetHistory()
history := logger.GetHistory()
if len(history) != 0 {
c.SSEvent("message", string(history))
c.Writer.Flush()
@@ -111,3 +122,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
}
}
}
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+304 -322
View File
@@ -8,85 +8,120 @@ import (
"mime/multipart"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
}
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
}
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
}
func TestProxyManager_SwapMultiProcess(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1, model2},
LogLevel: "error",
Groups: map[string]GroupConfig{
"G1": {
Swap: true,
Exclusive: false,
Members: []string{"model1"},
},
"G2": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
},
}
})
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
for modelID, requestedModel := range map[string]string{
"model1": profileModel1,
"model2": profileModel2,
} {
tests := []string{"model1", "model2"}
for _, requestedModel := range tests {
t.Run(requestedModel, func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel)
})
}
// make sure there's two loaded models
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
}
// Test that a persistent group is not affected by the swapping behaviour of
// other groups.
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
Groups: map[string]GroupConfig{
// the forever group is persistent and should not be affected by model1
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model2"},
},
},
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// make requests to load all models, loading model1 should not affect model2
tests := []string{"model2", "model1"}
for _, requestedModel := range tests {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelID)
assert.Contains(t, w.Body.String(), requestedModel)
}
// make sure there's two loaded models
assert.Len(t, proxy.currentProcesses, 2)
_, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
}
// When a request for a different model comes in ProxyManager should wait until
@@ -96,17 +131,18 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
t.Skip("skipping slow test")
}
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
results := map[string]string{}
@@ -122,15 +158,16 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
}
mu.Lock()
results[key] = w.Body.String()
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"]
mu.Unlock()
}(key)
@@ -146,13 +183,14 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
}
func TestProxyManager_ListModelsHandler(t *testing.T) {
config := &Config{
config := Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -163,7 +201,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
w := httptest.NewRecorder()
// Call the listModelsHandler
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
// Check the response status code
assert.Equal(t, http.StatusOK, w.Code)
@@ -213,50 +251,6 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
assert.Empty(t, expectedModels, "not all expected models were returned")
}
func TestProxyManager_ProfileNonMember(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileMemberName := ProcessKeyName("test", model1)
profileNonMemberName := ProcessKeyName("test", model2)
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
model1: getTestSimpleResponderConfig("model1"),
model2: getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {model1},
},
}
proxy := New(config)
defer proxy.StopProcesses()
// actual member of profile
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "model1")
}
// actual model, but non-member will 404
{
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
}
func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
@@ -268,23 +262,27 @@ func TestProxyManager_Shutdown(t *testing.T) {
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/"
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2", "model3"},
},
Models: map[string]ModelConfig{
"model1": model1Config,
"model2": model2Config,
"model3": model3Config,
},
}
LogLevel: "error",
Groups: map[string]GroupConfig{
"test": {
Swap: false,
Members: []string{"model1", "model2", "model3"},
},
},
})
proxy := New(config)
// Start all the processes
var wg sync.WaitGroup
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
for _, modelName := range []string{"model1", "model2", "model3"} {
wg.Add(1)
go func(modelName string) {
defer wg.Done()
@@ -292,11 +290,10 @@ func TestProxyManager_Shutdown(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
// send a request to trigger the proxy to load
proxy.HandlerFunc(w, req)
// send a request to trigger the proxy to load ... this should hang waiting for start up
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
//fmt.Println(w.Code, w.Body.String())
}(modelName)
}
@@ -308,64 +305,43 @@ func TestProxyManager_Shutdown(t *testing.T) {
}
func TestProxyManager_Unload(t *testing.T) {
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
}
LogLevel: "error",
})
proxy := New(config)
proc, err := proxy.swapModel("model1")
assert.NoError(t, err)
assert.NotNil(t, proc)
assert.Len(t, proxy.currentProcesses, 1)
req := httptest.NewRequest("GET", "/unload", nil)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// issue 62, strip profile slug from model name
func TestProxyManager_StripProfileSlug(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
// give it a bit of time to stop
<-time.After(time.Millisecond * 250)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
}
// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
}
LogLevel: "warn",
})
// Define a helper struct to parse the JSON response.
type RunningResponse struct {
@@ -377,12 +353,12 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Create proxy once for all tests
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -400,13 +376,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder()
proxy.HandlerFunc(w, req)
proxy.ServeHTTP(w, req)
var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
@@ -420,224 +396,230 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State)
})
t.Run("multiple models via profile", func(t *testing.T) {
// Load more than one model.
for _, model := range []string{"model1", "model2"} {
profileModel := ProcessKeyName("test", model)
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Simulate the browser call.
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
// The JSON response must be valid.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// The response should contain 2 models.
assert.Len(t, response.Running, 2)
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
}
// Iterate through the models and check their states as well.
for _, entry := range response.Running {
_, exists := expectedModels[entry.Model]
assert.True(t, exists, "unexpected model %s", entry.Model)
assert.Equal(t, "ready", entry.State)
delete(expectedModels, entry.Model)
}
// Since we deleted each model while testing for its validity we should have no more models in the response.
assert.Empty(t, expectedModels, "unexpected additional models in response")
})
}
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"},
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
testCases := []struct {
name string
modelInput string
expectModel string
}{
{
name: "With Profile Prefix",
modelInput: "test:TheExpectedModel",
expectModel: "TheExpectedModel", // Profile prefix should be stripped
},
{
name: "Without Profile Prefix",
modelInput: "TheExpectedModel",
expectModel: "TheExpectedModel", // Should remain the same
},
}
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte("TheExpectedModel"))
assert.NoError(t, err)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tc.modelInput))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
// Generate random content length between 10 and 20
contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
// Generate random content length between 10 and 20
contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tc.expectModel, response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
})
}
}
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, "TheExpectedModel", response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
}
// Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := &Config{
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{
"model1": modelConfig,
},
}
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
defer proxy.StopProcesses(StopWaitForInflightRequest)
requestedModel := "model1"
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
// make sure the content length was set correctly
// simple-responder will return the content length it got in the response
body := w.Body.Bytes()
contentLength := int(gjson.GetBytes(body, "h_content_length").Int())
assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength)
})
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
tests := []struct {
description string
requestedModel string
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
}
func TestProxyManager_Upstream(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
}
func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
}
+43
View File
@@ -0,0 +1,43 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
+77
View File
@@ -0,0 +1,77 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}
+189
View File
@@ -0,0 +1,189 @@
#!/bin/sh
# This script installs llama-swap on Linux.
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
status() { echo ">>> $*" >&2; }
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
warning() { echo "${red}WARNING:${plain} $*"; }
available() { command -v $1 >/dev/null; }
require() {
local MISSING=''
for TOOL in $*; do
if ! available $TOOL; then
MISSING="$MISSING $TOOL"
fi
done
echo $MISSING
}
SUDO=
if [ "$(id -u)" -ne 0 ]; then
if ! available sudo; then
error "This script requires superuser permissions. Please re-run as root."
fi
SUDO="sudo"
fi
NEEDS=$(require curl tee jq tar)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;;
*) error "Unsupported architecture: $ARCH" ;;
esac
IS_WSL2=false
KERN=$(uname -r)
case "$KERN" in
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
*) ;;
esac
download_binary() {
ASSET_NAME="linux_$ARCH"
# Fetch the latest release info and extract the matching asset URL
DL_URL=$(curl -s "https://api.github.com/repos/mostlygeek/llama-swap/releases/latest" | \
jq -r --arg name "$ASSET_NAME" \
'.assets[] | select(.name | contains($name)) | .browser_download_url')
# Check if a URL was successfully extracted
if [ -z "$DL_URL" ]; then
error "No matching asset found with name containing '$ASSET_NAME'."
fi
status "Downloading Linux $ARCH binary"
curl -s -L "$DL_URL" | $SUDO tar -xzf - -C /usr/local/bin llama-swap
}
download_binary
configure_systemd() {
if ! id llama-swap >/dev/null 2>&1; then
status "Creating llama-swap user..."
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
fi
if getent group render >/dev/null 2>&1; then
status "Adding llama-swap user to render group..."
$SUDO usermod -a -G render llama-swap
fi
if getent group video >/dev/null 2>&1; then
status "Adding llama-swap user to video group..."
$SUDO usermod -a -G video llama-swap
fi
if getent group docker >/dev/null 2>&1; then
status "Adding llama-swap user to docker group..."
$SUDO usermod -a -G docker llama-swap
fi
status "Adding current user to llama-swap group..."
$SUDO usermod -a -G llama-swap $(whoami)
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
status "Creating default config.yaml..."
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
# default 15s likely to fail for default models due to downloading models
healthCheckTimeout: 60
models:
"qwen2.5":
cmd: |
docker run
--rm
-p \${PORT}:8080
--name qwen2.5
ghcr.io/ggml-org/llama.cpp:server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
cmdStop: docker stop qwen2.5
"smollm2":
cmd: |
docker run
--rm
-p \${PORT}:8080
--name smollm2
ghcr.io/ggml-org/llama.cpp:server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
cmdStop: docker stop smollm2
EOF
fi
status "Creating llama-swap systemd service..."
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
[Unit]
Description=llama-swap
After=network.target
[Service]
User=llama-swap
Group=llama-swap
# set this to match your environment
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config
Restart=on-failure
RestartSec=3
StartLimitBurst=3
StartLimitInterval=30
[Install]
WantedBy=multi-user.target
EOF
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
case $SYSTEMCTL_RUNNING in
running|degraded)
status "Enabling and starting llama-swap service..."
$SUDO systemctl daemon-reload
$SUDO systemctl enable llama-swap
start_service() { $SUDO systemctl restart llama-swap; }
trap start_service EXIT
;;
*)
warning "systemd is not running"
if [ "$IS_WSL2" = true ]; then
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
fi
;;
esac
}
if available systemctl; then
configure_systemd
fi
install_success() {
status 'The llama-swap API is now available at 127.0.0.1:8080.'
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
status 'Install complete.'
}
# WSL2 only supports GPUs via nvidia passthrough
# so check for nvidia-smi to determine if GPU is available
if [ "$IS_WSL2" = true ]; then
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
status "Nvidia GPU detected."
fi
exit 0
fi
install_success
+68
View File
@@ -0,0 +1,68 @@
#!/bin/sh
# This script uninstalls llama-swap on Linux.
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
status() { echo ">>> $*" >&2; }
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
warning() { echo "${red}WARNING:${plain} $*"; }
available() { command -v $1 >/dev/null; }
SUDO=
if [ "$(id -u)" -ne 0 ]; then
if ! available sudo; then
error "This script requires superuser permissions. Please re-run as root."
fi
SUDO="sudo"
fi
configure_systemd() {
status "Stopping llama-swap service..."
$SUDO systemctl stop llama-swap
status "Disabling llama-swap service..."
$SUDO systemctl disable llama-swap
}
if available systemctl; then
configure_systemd
fi
if available llama-swap; then
status "Removing llama-swap binary..."
$SUDO rm $(which llama-swap)
fi
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
while true; do
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
read answer
case "$answer" in
[Yy]* )
$SUDO rm -r /usr/share/llama-swap
break
;;
[Nn]* | "" )
break
;;
* )
echo "Invalid input. Please enter y or n."
;;
esac
done
fi
if id llama-swap >/dev/null 2>&1; then
status "Removing llama-swap user..."
$SUDO userdel llama-swap
fi
if getent group llama-swap >/dev/null 2>&1; then
status "Removing llama-swap group..."
$SUDO groupdel llama-swap
fi