Compare commits

...

25 Commits

Author SHA1 Message Date
Benson Wong 02aee4e86d remove noisy debug print message 2025-05-20 10:43:10 -07:00
Benson Wong f45896d395 add guard to avoid unnecessary logic in Process.Shutdown 2025-05-20 10:43:09 -07:00
choyuansu f7e46a359f Add link to unload endpoint in upstream list (#140)
* Add link to open /unload
2025-05-20 08:31:44 -07:00
choyuansu c260907415 Add linux install and uninstall shell scripts (#139)
Contribution for install, and uninstall llama-swap in linux.
2025-05-19 12:03:33 -07:00
Benson Wong b83a5fa291 make Failed stated recoverable (#137)
A process in the failed state can transition to stopped either by calling /unload or swapping to another model.
2025-05-16 19:54:44 -07:00
Benson Wong 6e2ff28d59 improve cmdStop docs [no ci] 2025-05-16 13:52:04 -07:00
Benson Wong a8b81f2799 Add stopCmd for custom stopping instructions (#136)
Allow configuration of how a model is stopped before swapping. Setting `cmdStop` in the configuration will override the default behaviour and enables better integration with other process/container managers like docker or podman.
2025-05-16 13:48:42 -07:00
Benson Wong f9ee7156dc update configuration examples for multiline yaml commands #133 2025-05-16 11:45:39 -07:00
fakezeta 2d00120781 Update proxymanager.go (#135) 2025-05-16 06:45:09 -07:00
Benson Wong afc9aef058 Fix #133 SanitizeCommand removes comments (#134) 2025-05-15 15:28:50 -07:00
Benson Wong d7b390df74 Add GH Action for Testing on Windows (#132)
* Add windows specific test changes
* Change the command line parsing library - Possible breaking changes for windows users!
2025-05-14 21:51:53 -07:00
Benson Wong 5025c2f1f3 Add GH windows tests (not working yet) 2025-05-14 19:58:22 -07:00
Benson Wong e3a0b013c1 add content length test for #131 2025-05-14 19:50:01 -07:00
Fadenfire f5763a94a0 Fix content length being incorrect when useModelName is used (#131)
* Fix content length being incorrect when useModelName is used
* Update c.Request.ContentLength as well
2025-05-14 19:37:54 -07:00
Benson Wong 8ada72eb57 Update issue templates 2025-05-14 16:36:32 -07:00
Benson Wong 2441b383d3 Make checking for process killed status more robust 2025-05-14 16:26:56 -07:00
Benson Wong 25f251699c Prevent StateFailed after SIGKILL (#129)
Closes #125
2025-05-14 10:47:35 -07:00
Benson Wong 7f37bcc6eb Improve testing around using SIGKILL (#127)
* Add test for SIGKILL of process
* silent TestProxyManager_RunningEndpoint debug output
* Ref #125
2025-05-13 21:21:52 -07:00
Benson Wong 519c3a4d22 Change /unload to not wait for inflight requests (#125)
Sometimes upstreams can accept HTTP but never respond causing requests
to build up waiting for a response. This can block Process.Stop() as
that waits for inflight requests to finish. This change refactors the
code to not wait when attempting to shutdown the process.
2025-05-13 11:39:19 -07:00
Benson Wong 9dc4bcb46c Add a concurrency limit to Process.ProxyRequest (#123) 2025-05-12 18:12:52 -07:00
Benson Wong cb876c143b update example config 2025-05-12 10:20:18 -07:00
Sam bc652709a5 Add config hot-reload (#106)
introduce --watch-config command line option to reload ProxyManager when configuration changes.
2025-05-11 17:37:00 -07:00
Thammachart Chinvarapon 9548931258 ci: re-enabled intel build pipeline (#121) 2025-05-11 00:19:57 -07:00
Benson Wong 5c5a5da664 Update README.md
removed extra section.
2025-05-06 06:59:15 -07:00
Benson Wong aa9ef59aa5 Create .coderabbit.yaml 2025-05-05 19:47:23 -07:00
27 changed files with 1075 additions and 215 deletions
+15
View File
@@ -0,0 +1,15 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
poem: false
review_status: true
collapse_walkthrough: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true
+37
View File
@@ -0,0 +1,37 @@
---
name: Bug Report
about: Something is not working as expected...
title: ''
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Expected behaviour**
A clear and concise description of what you expected to happen.
**Operating system and version**
- OS: (linux, osx, windows, freebsd, etc)
- GPUs: (list architecture)
**My Configuration**
```yaml
# copy / paste your configuration here
```
**Proxy Logs**
```
# copy / paste from /logs
```
**Upstream Logs**
```
# copy/paste from /logs
```
+1 -2
View File
@@ -15,8 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
#platform: [intel, cuda, vulkan, cpu, musa] platform: [intel, cuda, vulkan, cpu, musa]
platform: [cuda, vulkan, cpu, musa]
fail-fast: false fail-fast: false
steps: steps:
- name: Checkout code - name: Checkout code
+50
View File
@@ -0,0 +1,50 @@
name: Windows CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
shell: bash
run: make simple-responder-windows
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
- name: Test all
shell: bash
run: make test-all
+18 -3
View File
@@ -1,6 +1,4 @@
# This workflow will build a golang project name: Linux CI
name: CI
on: on:
push: push:
@@ -24,9 +22,26 @@ jobs:
with: with:
go-version: '1.23' go-version: '1.23'
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping # necessary for testing proxy/Process swapping
- name: Create simple-responder - name: Create simple-responder
run: make simple-responder run: make simple-responder
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
- name: Test all - name: Test all
run: make test-all run: make test-all
+6 -2
View File
@@ -20,10 +20,10 @@ clean:
rm -rf $(BUILD_DIR) rm -rf $(BUILD_DIR)
test: test:
go test -short -v ./proxy go test -short -v -count=1 ./proxy
test-all: test-all:
go test -v ./proxy go test -v -count=1 ./proxy
# Build OSX binary # Build OSX binary
mac: mac:
@@ -46,6 +46,10 @@ simple-responder:
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
simple-responder-windows:
@echo "Building simple responder for windows"
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
# Ensure build directory exists # Ensure build directory exists
$(BUILD_DIR): $(BUILD_DIR):
mkdir -p $(BUILD_DIR) mkdir -p $(BUILD_DIR)
+19 -13
View File
@@ -46,14 +46,14 @@ llama-swap's configuration is purposefully simple.
models: models:
"qwen2.5": "qwen2.5":
proxy: "http://127.0.0.1:9999" proxy: "http://127.0.0.1:9999"
cmd: > cmd: |
/app/llama-server /app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M -hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999 --port 9999
"smollm2": "smollm2":
proxy: "http://127.0.0.1:9999" proxy: "http://127.0.0.1:9999"
cmd: > cmd: |
/app/llama-server /app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M -hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999 --port 9999
@@ -82,7 +82,7 @@ startPort: 10001
models: models:
"llama": "llama":
# multiline for readability # multiline for readability
cmd: > cmd: |
llama-server --port 8999 llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf --model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
@@ -123,12 +123,18 @@ models:
# Docker Support (v26.1.4+ required!) # Docker Support (v26.1.4+ required!)
"docker-llama": "docker-llama":
proxy: "http://127.0.0.1:${PORT}" proxy: "http://127.0.0.1:${PORT}"
cmd: > cmd: |
docker run --name dockertest docker run --name dockertest
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models --init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# use a custom command to stop the model when swapping. By default
# this is SIGTERM on POSIX systems, and taskkill on Windows systems
# the ${PID} variable can be used in cmdStop, it will be automatically replaced
# with the PID of the running model
cmdStop: docker stop dockertest
# Groups provide advanced controls over model swapping behaviour. Using groups # Groups provide advanced controls over model swapping behaviour. Using groups
# some models can be kept loaded indefinitely, while others are swapped out. # some models can be kept loaded indefinitely, while others are swapped out.
# #
@@ -189,18 +195,13 @@ groups:
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases. - [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest. - [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations. - [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
## Configuration
llama-s
</details> </details>
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap)) ## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker is the quickest way to try out llama-swap: Docker is the quickest way to try out llama-swap:
``` ```shell
# use CPU inference # use CPU inference
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu $ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
@@ -236,7 +237,7 @@ Specific versions are also available and are tagged with the llama-swap, archite
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration. Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
``` ```shell
$ docker run -it --rm --runtime nvidia -p 9292:8080 \ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \ -v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \ -v /path/to/custom/config.yaml:/app/config.yaml \
@@ -251,7 +252,12 @@ Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are
1. Create a configuration file, see [config.example.yaml](config.example.yaml) 1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture. 1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
1. Run the binary with `llama-swap --config path/to/config.yaml` 1. Run the binary with `llama-swap --config path/to/config.yaml`.
Available flags:
- `--config`: Path to the configuration file (default: `config.yaml`).
- `--listen`: Address and port to listen on (default: `:8080`).
- `--version`: Show version information and exit.
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
### Building from source ### Building from source
@@ -266,7 +272,7 @@ Open the `http://<host>/logs` with your browser to get a web interface with stre
Of course, CLI access is also supported: Of course, CLI access is also supported:
``` ```shell
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' curl http://host/logs'
+20 -24
View File
@@ -5,13 +5,20 @@ healthCheckTimeout: 90
# valid log levels: debug, info (default), warn, error # valid log levels: debug, info (default), warn, error
logLevel: debug logLevel: debug
# creating a coding profile with models for code generation and general questions
groups:
coding:
swap: false
members:
- "qwen"
- "llama"
models: models:
"llama": "llama":
cmd: > cmd: |
models/llama-server-osx models/llama-server-osx
--port 9001 --port ${PORT}
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf -m models/Llama-3.2-1B-Instruct-Q4_0.gguf
proxy: http://127.0.0.1:9001
# list of model name aliases this llama.cpp instance can serve # list of model name aliases this llama.cpp instance can serve
aliases: aliases:
@@ -24,17 +31,15 @@ models:
ttl: 5 ttl: 5
"qwen": "qwen":
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9002
aliases: aliases:
- gpt-3.5-turbo - gpt-3.5-turbo
# Embedding example with Nomic # Embedding example with Nomic
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF # https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
"nomic": "nomic":
proxy: http://127.0.0.1:9005 cmd: |
cmd: > models/llama-server-osx --port ${PORT}
models/llama-server-osx --port 9005
-m models/nomic-embed-text-v1.5.Q8_0.gguf -m models/nomic-embed-text-v1.5.Q8_0.gguf
--ctx-size 8192 --ctx-size 8192
--batch-size 8192 --batch-size 8192
@@ -46,19 +51,17 @@ models:
# Reranking example with bge-reranker # Reranking example with bge-reranker
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF # https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
"bge-reranker": "bge-reranker":
proxy: http://127.0.0.1:9006 cmd: |
cmd: > models/llama-server-osx --port ${PORT}
models/llama-server-osx --port 9006
-m models/bge-reranker-v2-m3-Q4_K_M.gguf -m models/bge-reranker-v2-m3-Q4_K_M.gguf
--ctx-size 8192 --ctx-size 8192
--reranking --reranking
# Docker Support (v26.1.4+ required!) # Docker Support (v26.1.4+ required!)
"dockertest": "dockertest":
proxy: "http://127.0.0.1:9790" cmd: |
cmd: >
docker run --name dockertest docker run --name dockertest
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models --init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggerganov/llama.cpp:server ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
@@ -67,8 +70,7 @@ models:
env: env:
- CUDA_VISIBLE_DEVICES=0,1 - CUDA_VISIBLE_DEVICES=0,1
- env1=hello - env1=hello
cmd: build/simple-responder --port 8999 cmd: build/simple-responder --port ${PORT}
proxy: http://127.0.0.1:8999
unlisted: true unlisted: true
# use "none" to skip check. Caution this may cause some requests to fail # use "none" to skip check. Caution this may cause some requests to fail
@@ -83,10 +85,4 @@ models:
"broken_timeout": "broken_timeout":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9000 proxy: http://127.0.0.1:9000
unlisted: true unlisted: true
# creating a coding profile with models for code generation and general questions
profiles:
coding:
- "qwen"
- "llama"
+2 -1
View File
@@ -3,8 +3,8 @@ module github.com/mostlygeek/llama-swap
go 1.23.0 go 1.23.0
require ( require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
@@ -12,6 +12,7 @@ require (
) )
require ( require (
github.com/billziss-gh/golib v0.2.0 // indirect
github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/base64x v0.1.4 // indirect
+10 -18
View File
@@ -1,3 +1,5 @@
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -9,12 +11,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
@@ -23,6 +29,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
@@ -74,34 +82,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+137 -11
View File
@@ -1,25 +1,34 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"log"
"net/http"
"os" "os"
"os/signal" "os/signal"
"path/filepath"
"syscall" "syscall"
"time"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy" "github.com/mostlygeek/llama-swap/proxy"
) )
var version string = "0" var (
var commit string = "abcd1234" version string = "0"
var date = "unknown" commit string = "abcd1234"
date string = "unknown"
)
func main() { func main() {
// Define a command-line flag for the port // Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name") configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port") listenStr := flag.String("listen", ":8080", "listen ip/port")
showVersion := flag.Bool("version", false, "show version of build") showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
@@ -46,18 +55,135 @@ func main() {
proxyManager := proxy.New(config) proxyManager := proxy.New(config)
// Setup channels for server management
reloadChan := make(chan *proxy.ProxyManager)
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create server with initial handler
srv := &http.Server{
Addr: *listenStr,
Handler: proxyManager,
}
// Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() { go func() {
<-sigChan if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Println("Shutting down llama-swap") fmt.Printf("Fatal server error: %v\n", err)
proxyManager.Shutdown() close(exitChan)
os.Exit(0) }
}() }()
fmt.Println("llama-swap listening on " + *listenStr) // Handle config reloads and signals
if err := proxyManager.Run(*listenStr); err != nil { go func() {
fmt.Printf("Server error: %v\n", err) currentManager := proxyManager
os.Exit(1) for {
select {
case newManager := <-reloadChan:
log.Println("Config change detected, waiting for in-flight requests to complete...")
// Stop old manager processes gracefully (this waits for in-flight requests)
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
// Now do a full shutdown to clear the process map
currentManager.Shutdown()
currentManager = newManager
srv.Handler = newManager
log.Println("Server handler updated with new config")
case sig := <-sigChan:
fmt.Printf("Received signal %v, shutting down...\n", sig)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentManager.Shutdown()
if err := srv.Shutdown(ctx); err != nil {
fmt.Printf("Server shutdown error: %v\n", err)
}
close(exitChan)
return
}
}
}()
// Start file watcher if requested
if *watchConfig {
absConfigPath, err := filepath.Abs(*configPath)
if err != nil {
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
} else {
go watchConfigFileWithReload(absConfigPath, reloadChan)
}
}
// Wait for exit signal
<-exitChan
}
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
return
}
defer watcher.Close()
err = watcher.Add(configPath)
if err != nil {
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
return
}
log.Printf("Watching config file for changes: %s", configPath)
var debounceTimer *time.Timer
debounceDuration := 2 * time.Second
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
// We only care about writes to the specific config file
if event.Name == configPath && event.Has(fsnotify.Write) {
// Reset or start the debounce timer
if debounceTimer != nil {
debounceTimer.Stop()
}
debounceTimer = time.AfterFunc(debounceDuration, func() {
log.Printf("Config file modified: %s, reloading...", event.Name)
// Try up to 3 times with exponential backoff
var newConfig proxy.Config
var err error
for retries := 0; retries < 3; retries++ {
// Load new configuration
newConfig, err = proxy.LoadConfig(configPath)
if err == nil {
break
}
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
if retries < 2 {
time.Sleep(time.Duration(1<<retries) * time.Second)
}
}
if err != nil {
log.Printf("Failed to load new config after retries: %v", err)
return
}
// Create new ProxyManager with new config
newPM := proxy.New(newConfig)
reloadChan <- newPM
log.Println("Config reloaded successfully")
})
}
case err, ok := <-watcher.Errors:
if !ok {
log.Println("File watcher error channel closed.")
return
}
log.Printf("File watcher error: %v", err)
}
} }
} }
+34 -3
View File
@@ -26,6 +26,8 @@ func main() {
silent := flag.Bool("silent", false, "disable all logging") silent := flag.Bool("silent", false, "disable all logging")
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
// Create a new Gin router // Create a new Gin router
@@ -190,6 +192,10 @@ func main() {
log.SetOutput(io.Discard) log.SetOutput(io.Discard)
} }
if !*silent {
fmt.Printf("My PID: %d\n", os.Getpid())
}
go func() { go func() {
log.Printf("simple-responder listening on %s\n", address) log.Printf("simple-responder listening on %s\n", address)
// service connections // service connections
@@ -200,11 +206,36 @@ func main() {
// Wait for interrupt signal to gracefully shutdown the server with // Wait for interrupt signal to gracefully shutdown the server with
// a timeout of 5 seconds. // a timeout of 5 seconds.
quit := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM // kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT // kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it // kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-quit
countSigInt := 0
runloop:
for {
signal := <-sigChan
switch signal {
case syscall.SIGINT:
countSigInt++
if countSigInt > 1 {
break runloop
} else {
log.Println("Recieved SIGINT, send another SIGINT to shutdown")
}
case syscall.SIGTERM:
if *ignoreSigTerm {
log.Println("Ignoring SIGTERM")
} else {
log.Println("Recieved SIGTERM, shutting down")
break runloop
}
default:
break runloop
}
}
log.Println("simple-responder shutting down") log.Println("simple-responder shutting down")
} }
+30 -9
View File
@@ -4,11 +4,12 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"runtime"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"github.com/google/shlex" "github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -16,6 +17,7 @@ const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct { type ModelConfig struct {
Cmd string `yaml:"cmd"` Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"` Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"` Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"` Env []string `yaml:"env"`
@@ -23,6 +25,9 @@ type ModelConfig struct {
UnloadAfter int `yaml:"ttl"` UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"` Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"` UseModelName string `yaml:"useModelName"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
} }
func (m *ModelConfig) SanitizedCommand() ([]string, error) { func (m *ModelConfig) SanitizedCommand() ([]string, error) {
@@ -131,7 +136,6 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
} }
} }
// iterate over the models and replace any ${PORT} with the next available port
// Get and sort all model IDs first, makes testing more consistent // Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models)) modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models { for modelId := range config.Models {
@@ -139,10 +143,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
} }
sort.Strings(modelIds) // This guarantees stable iteration order sort.Strings(modelIds) // This guarantees stable iteration order
// iterate over the sorted models
nextPort := config.StartPort nextPort := config.StartPort
for _, modelId := range modelIds { for _, modelId := range modelIds {
modelConfig := config.Models[modelId] modelConfig := config.Models[modelId]
// iterate over the models and replace any ${PORT} with the next available port
if strings.Contains(modelConfig.Cmd, "${PORT}") { if strings.Contains(modelConfig.Cmd, "${PORT}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort)) modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
if modelConfig.Proxy == "" { if modelConfig.Proxy == "" {
@@ -156,6 +160,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId) return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
} }
} }
config = AddDefaultGroupToConfig(config) config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups // check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in memberUsage := make(map[string]string) // maps member to group it appears in
@@ -225,14 +230,30 @@ func AddDefaultGroupToConfig(config Config) Config {
} }
func SanitizeCommand(cmdStr string) ([]string, error) { func SanitizeCommand(cmdStr string) ([]string, error) {
// Remove trailing backslashes var cleanedLines []string
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ") for _, line := range strings.Split(cmdStr, "\n") {
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ") trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments // Split the command into arguments
args, err := shlex.Split(cmdStr) var args []string
if err != nil { if runtime.GOOS == "windows" {
return nil, err args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
} }
// Ensure the command is not empty // Ensure the command is not empty
+42
View File
@@ -0,0 +1,42 @@
//go:build !windows
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_SanitizeCommand(t *testing.T) {
// Test a command with spaces and newlines
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
# comment 1
--arg3 123 \
# comment 2
--arg4 '"string in string"'
# this will get stripped out as well as the white space above
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}
-28
View File
@@ -258,34 +258,6 @@ func TestConfig_FindConfig(t *testing.T) {
assert.Equal(t, ModelConfig{}, modelConfig) assert.Equal(t, ModelConfig{}, modelConfig)
} }
func TestConfig_SanitizeCommand(t *testing.T) {
// Test a command with spaces and newlines
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
--arg2 'single quotes'
-s
--arg3 123 \
--arg4 '"string in string"'
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"--arg2", "single quotes",
"-s",
"--arg3", "123",
"--arg4", `"string in string"`,
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}
func TestConfig_AutomaticPortAssignments(t *testing.T) { func TestConfig_AutomaticPortAssignments(t *testing.T) {
t.Run("Default Port Ranges", func(t *testing.T) { t.Run("Default Port Ranges", func(t *testing.T) {
+41
View File
@@ -0,0 +1,41 @@
//go:build windows
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_SanitizeCommand(t *testing.T) {
// does not support single quoted strings like in config_posix_test.go
args, err := SanitizeCommand(`python model1.py \
-a "double quotes" \
-s
--arg3 123 \
# comment 2
--arg4 '"string in string"'
# this will get stripped out as well as the white space above
-c "'single quoted'"
`)
assert.NoError(t, err)
assert.Equal(t, []string{
"python", "model1.py",
"-a", "double quotes",
"-s",
"--arg3", "123",
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
"-c", `'single quoted'`,
}, args)
// Test an empty command
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}
+12 -3
View File
@@ -45,17 +45,26 @@ func TestMain(m *testing.M) {
func getSimpleResponderPath() string { func getSimpleResponderPath() string {
goos := runtime.GOOS goos := runtime.GOOS
goarch := runtime.GOARCH goarch := runtime.GOARCH
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
if goos == "windows" {
return filepath.Join("..", "build", "simple-responder.exe")
} else {
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
} }
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig { func getTestPort() int {
portMutex.Lock() portMutex.Lock()
defer portMutex.Unlock() defer portMutex.Unlock()
port := nextTestPort port := nextTestPort
nextTestPort++ nextTestPort++
return getTestSimpleResponderConfigPort(expectedMessage, port) return port
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
} }
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig { func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
+126 -19
View File
@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os/exec" "os/exec"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -30,6 +31,13 @@ const (
StateShutdown ProcessState = ProcessState("shutdown") StateShutdown ProcessState = ProcessState("shutdown")
) )
type StopStrategy int
const (
StopImmediately StopStrategy = iota
StopWaitForInflightRequest
)
type Process struct { type Process struct {
ID string ID string
config ModelConfig config ModelConfig
@@ -57,10 +65,24 @@ type Process struct {
// for managing shutdown state // for managing shutdown state
shutdownCtx context.Context shutdownCtx context.Context
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
// for managing concurrency limits
concurrencyLimitSemaphore chan struct{}
// stop timeout waiting for graceful shutdown
gracefulStopTimeout time.Duration
// track that this happened
upstreamWasStoppedWithKill bool
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
concurrentLimit := 10
if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit
}
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
@@ -73,6 +95,13 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
state: StateStopped, state: StateStopped,
shutdownCtx: ctx, shutdownCtx: ctx,
shutdownCancel: cancel, shutdownCancel: cancel,
// concurrency limit
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
// stop timeout
gracefulStopTimeout: 5 * time.Second,
upstreamWasStoppedWithKill: false,
} }
} }
@@ -119,7 +148,9 @@ func isValidTransition(from, to ProcessState) bool {
return to == StateStopping return to == StateStopping
case StateStopping: case StateStopping:
return to == StateStopped || to == StateShutdown return to == StateStopped || to == StateShutdown
case StateFailed, StateShutdown: case StateFailed:
return to == StateStopping
case StateShutdown:
return false // No transitions allowed from these states return false // No transitions allowed from these states
} }
return false return false
@@ -185,10 +216,19 @@ func (p *Process) start() error {
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
// Capture the exit error for later signaling // Capture the exit error for later signalling
go func() { go func() {
exitErr := p.cmd.Wait() exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr) p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
// there is a race condition when SIGKILL is used, p.cmd.Wait() returns, and then
// the code below fires, putting an error into cmdWaitChan. This code is to prevent this
if p.upstreamWasStoppedWithKill {
p.proxyLogger.Debugf("<%s> process was killed, NOT sending exitErr: %v", p.ID, exitErr)
p.upstreamWasStoppedWithKill = false
return
}
p.cmdWaitChan <- exitErr p.cmdWaitChan <- exitErr
}() }()
@@ -260,9 +300,9 @@ func (p *Process) start() error {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline() endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime) ttl := time.Until(endTime)
p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds()) p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
} else { } else {
p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err) p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
} }
} }
} }
@@ -301,23 +341,42 @@ func (p *Process) start() error {
} }
} }
// Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() { func (p *Process) Stop() {
if !isValidTransition(p.CurrentState(), StateStopping) { if !isValidTransition(p.CurrentState(), StateStopping) {
return return
} }
// wait for any inflight requests before proceeding // wait for any inflight requests before proceeding
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
p.inFlightRequests.Wait() p.inFlightRequests.Wait()
p.proxyLogger.Debugf("<%s> Stopping process", p.ID) p.StopImmediately()
}
// calling Stop() when state is invalid is a no-op // StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
if curState, err := p.swapState(StateReady, StateStopping); err != nil { // If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return return
} }
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
currentState := p.CurrentState()
if currentState == StateFailed {
if curState, err := p.swapState(StateFailed, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Failed -> StateStopping err: %v, current state: %v", p.ID, err, curState)
return
}
} else {
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
return
}
}
// stop the process with a graceful exit timeout // stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second) p.stopCommand(p.gracefulStopTimeout)
if curState, err := p.swapState(StateStopping, StateStopped); err != nil { if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState) p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
@@ -326,10 +385,17 @@ func (p *Process) Stop() {
// Shutdown is called when llama-swap is shutting down. It will give a little bit // Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process // of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down // is in the state of starting, it will cancel it and shut it down. Once a process is in
// the StateShutdown state, it can not be started again.
func (p *Process) Shutdown() { func (p *Process) Shutdown() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
p.shutdownCancel() p.shutdownCancel()
p.stopCommand(5 * time.Second) p.stopCommand(p.gracefulStopTimeout)
// just force it to this state since there is no recovery from shutdown
p.state = StateShutdown p.state = StateShutdown
} }
@@ -345,31 +411,64 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
defer cancelTimeout() defer cancelTimeout()
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID) p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
return return
} }
if err := p.terminateProcess(); err != nil { // if err := p.terminateProcess(); err != nil {
p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err) // p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
// }
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" {
p.config.CmdStop = "taskkill /f /t /pid ${PID}"
}
if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
if err != nil {
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
return
}
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Stdout = p.processLogger
stopCmd.Stderr = p.processLogger
stopCmd.Env = p.config.Env
if err := stopCmd.Run(); err != nil {
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
return
}
} else {
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
return
}
} }
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID) p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID)
p.cmd.Process.Kill() p.upstreamWasStoppedWithKill = true
if err := p.cmd.Process.Kill(); err != nil {
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
}
case err := <-p.cmdWaitChan: case err := <-p.cmdWaitChan:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK // Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it // because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that ithe cmd crashed after the health check // through the health check. There is a possibility that the cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now. // succeeded but that's not a case llama-swap is handling for now.
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok { if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno) p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok { } else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") { if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Infof("<%s> Process stopped OK", p.ID) p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") { } else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID) p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
} else { } else {
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
} }
@@ -415,6 +514,14 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
select {
case p.concurrencyLimitSemaphore <- struct{}{}:
defer func() { <-p.concurrencyLimitSemaphore }()
default:
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}
p.inFlightRequests.Add(1) p.inFlightRequests.Add(1)
defer func() { defer func() {
p.lastRequestHandled = time.Now() p.lastRequestHandled = time.Now()
-9
View File
@@ -1,9 +0,0 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
-14
View File
@@ -1,14 +0,0 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+128
View File
@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"runtime"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -340,3 +341,130 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error()) assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed) assert.Equal(t, process.CurrentState(), StateFailed)
} }
func TestProcess_ConcurrencyLimit(t *testing.T) {
if testing.Short() {
t.Skip("skipping long concurrency limit test")
}
expectedMessage := "concurrency_limit_test"
config := getTestSimpleResponderConfig(expectedMessage)
// only allow 1 concurrent request at a time
config.ConcurrencyLimit = 1
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
defer process.Stop()
// launch a goroutine first to take up the semaphore
go func() {
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req1)
assert.Equal(t, http.StatusOK, w.Code)
}()
// let the goroutine start
<-time.After(time.Millisecond * 25)
denied := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, denied)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
}
func TestProcess_StopImmediately(t *testing.T) {
expectedMessage := "test_stop_immediate"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
// the upstream command
func TestProcess_ForceStopWithKill(t *testing.T) {
expectedMessage := "test_sigkill"
binaryPath := getSimpleResponderPath()
port := getTestPort()
config := ModelConfig{
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
// to force the process to exit
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
}
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
defer process.Stop()
// reduce to make testing go faster
process.gracefulStopTimeout = time.Second
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
waitChan := make(chan struct{})
go func() {
// slow, but will get killed by StopImmediate
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
// StatusOK because that was already sent before the kill
assert.Equal(t, http.StatusOK, w.Code)
// unexpected EOF because the kill happened, the "1" is sent before the kill
// then the unexpected EOF is sent after the kill
if runtime.GOOS == "windows" {
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
} else {
assert.Contains(t, w.Body.String(), "unexpected EOF")
}
close(waitChan)
}()
<-time.After(time.Millisecond)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
// the request should have been interrupted by SIGKILL
<-waitChan
}
func TestProcess_StopCmd(t *testing.T) {
config := getTestSimpleResponderConfig("test_stop_cmd")
if runtime.GOOS == "windows" {
config.CmdStop = "taskkill /f /t /pid ${PID}"
} else {
config.CmdStop = "kill -TERM ${PID}"
}
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}
+7 -6
View File
@@ -76,14 +76,10 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName) return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
} }
func (pg *ProcessGroup) StopProcesses() { func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock() pg.Lock()
defer pg.Unlock() defer pg.Unlock()
pg.stopProcesses()
}
// stopProcesses stops all processes in the group
func (pg *ProcessGroup) stopProcesses() {
if len(pg.processes) == 0 { if len(pg.processes) == 0 {
return return
} }
@@ -94,7 +90,12 @@ func (pg *ProcessGroup) stopProcesses() {
wg.Add(1) wg.Add(1)
go func(process *Process) { go func(process *Process) {
defer wg.Done() defer wg.Done()
process.Stop() switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
}(process) }(process)
} }
wg.Wait() wg.Wait()
+2 -2
View File
@@ -46,7 +46,7 @@ func TestProcessGroup_HasMember(t *testing.T) {
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses() defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"} tests := []string{"model1", "model2"}
@@ -74,7 +74,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses() defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model3", "model4"} tests := []string{"model3", "model4"}
+46 -16
View File
@@ -82,6 +82,11 @@ func New(config Config) *ProxyManager {
pm.processGroups[groupID] = processGroup pm.processGroups[groupID] = processGroup
} }
pm.setupGinEngine()
return pm
}
func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
// Start timer // Start timer
start := time.Now() start := time.Now()
@@ -192,19 +197,18 @@ func New(config Config) *ProxyManager {
// Disable console color for testing // Disable console color for testing
gin.DisableConsoleColor() gin.DisableConsoleColor()
return pm
} }
func (pm *ProxyManager) Run(addr ...string) error { // ServeHTTP implements http.Handler interface
return pm.ginEngine.Run(addr...) func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r) pm.ginEngine.ServeHTTP(w, r)
} }
func (pm *ProxyManager) StopProcesses() { // StopProcesses acquires a lock and stops all running upstream processes.
// This is the public method safe for concurrent calls.
// Unlike Shutdown, this method only stops the processes but doesn't perform
// a complete shutdown, allowing for process replacement without full termination.
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
@@ -214,15 +218,14 @@ func (pm *ProxyManager) StopProcesses() {
wg.Add(1) wg.Add(1)
go func(processGroup *ProcessGroup) { go func(processGroup *ProcessGroup) {
defer wg.Done() defer wg.Done()
processGroup.stopProcesses() processGroup.StopProcesses(strategy)
}(processGroup) }(processGroup)
} }
wg.Wait() wg.Wait()
} }
// Shutdown is called to shutdown all upstream processes // Shutdown stops all processes managed by this ProxyManager
// when llama-swap is shutting down.
func (pm *ProxyManager) Shutdown() { func (pm *ProxyManager) Shutdown() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
@@ -257,7 +260,7 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id) pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
for groupId, otherGroup := range pm.processGroups { for groupId, otherGroup := range pm.processGroups {
if groupId != processGroup.id && !otherGroup.persistent { if groupId != processGroup.id && !otherGroup.persistent {
otherGroup.StopProcesses() otherGroup.StopProcesses(StopWaitForInflightRequest)
} }
} }
} }
@@ -316,7 +319,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
func (pm *ProxyManager) upstreamIndex(c *gin.Context) { func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
var html strings.Builder var html strings.Builder
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>") html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><a href=\"/unload\">Unload all models</a><ul>")
// Extract keys and sort them // Extract keys and sort them
var modelIDs []string var modelIDs []string
@@ -331,7 +334,33 @@ func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
// Iterate over sorted keys // Iterate over sorted keys
for _, modelID := range modelIDs { for _, modelID := range modelIDs {
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID)) // Get process state
processGroup := pm.findGroupByModelName(modelID)
var state string
if processGroup != nil {
process := processGroup.processes[modelID]
if process != nil {
var stateStr string
switch process.CurrentState() {
case StateReady:
stateStr = "Ready"
case StateStarting:
stateStr = "Starting"
case StateStopping:
stateStr = "Stopping"
case StateFailed:
stateStr = "Failed"
case StateShutdown:
stateStr = "Shutdown"
case StateStopped:
stateStr = "Stopped"
default:
stateStr = "Unknown"
}
state = stateStr
}
}
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a> - %s</li>", modelID, modelID, state))
} }
html.WriteString("</ul></body></html>") html.WriteString("</ul></body></html>")
c.Header("Content-Type", "text/html") c.Header("Content-Type", "text/html")
@@ -371,7 +400,8 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
// dechunk it as we already have all the body bytes see issue #11 // dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding") c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
@@ -501,7 +531,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
} }
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses() pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK") c.String(http.StatusOK, "OK")
} }
+35 -32
View File
@@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
) )
func TestProxyManager_SwapProcessCorrectly(t *testing.T) { func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
@@ -27,14 +28,14 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
for _, modelName := range []string{"model1", "model2"} { for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName) assert.Contains(t, w.Body.String(), modelName)
} }
@@ -63,7 +64,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"} tests := []string{"model1", "model2"}
for _, requestedModel := range tests { for _, requestedModel := range tests {
@@ -72,10 +73,9 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel) assert.Contains(t, w.Body.String(), requestedModel)
}) })
} }
@@ -106,7 +106,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
// make requests to load all models, loading model1 should not affect model2 // make requests to load all models, loading model1 should not affect model2
tests := []string{"model2", "model1"} tests := []string{"model2", "model1"}
@@ -115,7 +115,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel) assert.Contains(t, w.Body.String(), requestedModel)
} }
@@ -142,7 +142,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
results := map[string]string{} results := map[string]string{}
@@ -158,14 +158,13 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key) t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
} }
mu.Lock() mu.Lock()
var response map[string]string var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"] results[key] = response["responseMessage"]
@@ -202,7 +201,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
// Call the listModelsHandler // Call the listModelsHandler
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
// Check the response status code // Check the response status code
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -292,7 +291,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
// send a request to trigger the proxy to load ... this should hang waiting for start up // send a request to trigger the proxy to load ... this should hang waiting for start up
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code) assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
}(modelName) }(modelName)
@@ -318,12 +317,12 @@ func TestProxyManager_Unload(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil) req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK") assert.Equal(t, w.Body.String(), "OK")
@@ -334,7 +333,6 @@ func TestProxyManager_Unload(t *testing.T) {
// Test issue #61 `Listing the current list of models and the loaded model.` // Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) { func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration // Shared configuration
config := AddDefaultGroupToConfig(Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
@@ -342,7 +340,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "debug", LogLevel: "warn",
}) })
// Define a helper struct to parse the JSON response. // Define a helper struct to parse the JSON response.
@@ -355,12 +353,12 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Create proxy once for all tests // Create proxy once for all tests
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("no models loaded", func(t *testing.T) { t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil) req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -378,13 +376,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
reqBody := `{"model":"model1"}` reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint. // Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil) req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
var response RunningResponse var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
@@ -410,7 +408,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
// Create a buffer with multipart form data // Create a buffer with multipart form data
var b bytes.Buffer var b bytes.Buffer
@@ -436,7 +434,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req) proxy.ServeHTTP(rec, req)
// Verify the response // Verify the response
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
@@ -451,7 +449,6 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
// Test useModelName in configuration sends overrides what is sent to upstream // Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) { func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel" upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName) modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName modelConfig.UseModelName = upstreamModelName
@@ -464,7 +461,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
requestedModel := "model1" requestedModel := "model1"
@@ -473,9 +470,15 @@ func TestProxyManager_UseModelName(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName) assert.Contains(t, w.Body.String(), upstreamModelName)
// make sure the content length was set correctly
// simple-responder will return the content length it got in the response
body := w.Body.Bytes()
contentLength := int(gjson.GetBytes(body, "h_content_length").Int())
assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength)
}) })
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) { t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
@@ -500,7 +503,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req) proxy.ServeHTTP(rec, req)
// Verify the response // Verify the response
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
@@ -560,7 +563,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil) req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders { for k, v := range tt.requestHeaders {
@@ -568,7 +571,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code) assert.Equal(t, tt.expectedStatus, w.Code)
@@ -589,10 +592,10 @@ func TestProxyManager_Upstream(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/upstream/model1/test", nil) req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req) proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String()) assert.Equal(t, "model1", rec.Body.String())
} }
@@ -607,13 +610,13 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
}) })
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
+189
View File
@@ -0,0 +1,189 @@
#!/bin/sh
# This script installs llama-swap on Linux.
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
status() { echo ">>> $*" >&2; }
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
warning() { echo "${red}WARNING:${plain} $*"; }
available() { command -v $1 >/dev/null; }
require() {
local MISSING=''
for TOOL in $*; do
if ! available $TOOL; then
MISSING="$MISSING $TOOL"
fi
done
echo $MISSING
}
SUDO=
if [ "$(id -u)" -ne 0 ]; then
if ! available sudo; then
error "This script requires superuser permissions. Please re-run as root."
fi
SUDO="sudo"
fi
NEEDS=$(require curl tee jq tar)
if [ -n "$NEEDS" ]; then
status "ERROR: The following tools are required but missing:"
for NEED in $NEEDS; do
echo " - $NEED"
done
exit 1
fi
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;;
*) error "Unsupported architecture: $ARCH" ;;
esac
IS_WSL2=false
KERN=$(uname -r)
case "$KERN" in
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
*) ;;
esac
download_binary() {
ASSET_NAME="linux_$ARCH"
# Fetch the latest release info and extract the matching asset URL
DL_URL=$(curl -s "https://api.github.com/repos/mostlygeek/llama-swap/releases/latest" | \
jq -r --arg name "$ASSET_NAME" \
'.assets[] | select(.name | contains($name)) | .browser_download_url')
# Check if a URL was successfully extracted
if [ -z "$DL_URL" ]; then
error "No matching asset found with name containing '$ASSET_NAME'."
fi
status "Downloading Linux $ARCH binary"
curl -s -L "$DL_URL" | $SUDO tar -xzf - -C /usr/local/bin llama-swap
}
download_binary
configure_systemd() {
if ! id llama-swap >/dev/null 2>&1; then
status "Creating llama-swap user..."
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
fi
if getent group render >/dev/null 2>&1; then
status "Adding llama-swap user to render group..."
$SUDO usermod -a -G render llama-swap
fi
if getent group video >/dev/null 2>&1; then
status "Adding llama-swap user to video group..."
$SUDO usermod -a -G video llama-swap
fi
if getent group docker >/dev/null 2>&1; then
status "Adding llama-swap user to docker group..."
$SUDO usermod -a -G docker llama-swap
fi
status "Adding current user to llama-swap group..."
$SUDO usermod -a -G llama-swap $(whoami)
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
status "Creating default config.yaml..."
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
# default 15s likely to fail for default models due to downloading models
healthCheckTimeout: 60
models:
"qwen2.5":
cmd: |
docker run
--rm
-p \${PORT}:8080
--name qwen2.5
ghcr.io/ggml-org/llama.cpp:server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
cmdStop: docker stop qwen2.5
"smollm2":
cmd: |
docker run
--rm
-p \${PORT}:8080
--name smollm2
ghcr.io/ggml-org/llama.cpp:server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
cmdStop: docker stop smollm2
EOF
fi
status "Creating llama-swap systemd service..."
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
[Unit]
Description=llama-swap
After=network.target
[Service]
User=llama-swap
Group=llama-swap
# set this to match your environment
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config
Restart=on-failure
RestartSec=3
StartLimitBurst=3
StartLimitInterval=30
[Install]
WantedBy=multi-user.target
EOF
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
case $SYSTEMCTL_RUNNING in
running|degraded)
status "Enabling and starting llama-swap service..."
$SUDO systemctl daemon-reload
$SUDO systemctl enable llama-swap
start_service() { $SUDO systemctl restart llama-swap; }
trap start_service EXIT
;;
*)
warning "systemd is not running"
if [ "$IS_WSL2" = true ]; then
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
fi
;;
esac
}
if available systemctl; then
configure_systemd
fi
install_success() {
status 'The llama-swap API is now available at 127.0.0.1:8080.'
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
status 'Install complete.'
}
# WSL2 only supports GPUs via nvidia passthrough
# so check for nvidia-smi to determine if GPU is available
if [ "$IS_WSL2" = true ]; then
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
status "Nvidia GPU detected."
fi
exit 0
fi
install_success
+68
View File
@@ -0,0 +1,68 @@
#!/bin/sh
# This script uninstalls llama-swap on Linux.
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
status() { echo ">>> $*" >&2; }
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
warning() { echo "${red}WARNING:${plain} $*"; }
available() { command -v $1 >/dev/null; }
SUDO=
if [ "$(id -u)" -ne 0 ]; then
if ! available sudo; then
error "This script requires superuser permissions. Please re-run as root."
fi
SUDO="sudo"
fi
configure_systemd() {
status "Stopping llama-swap service..."
$SUDO systemctl stop llama-swap
status "Disabling llama-swap service..."
$SUDO systemctl disable llama-swap
}
if available systemctl; then
configure_systemd
fi
if available llama-swap; then
status "Removing llama-swap binary..."
$SUDO rm $(which llama-swap)
fi
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
while true; do
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
read answer
case "$answer" in
[Yy]* )
$SUDO rm -r /usr/share/llama-swap
break
;;
[Nn]* | "" )
break
;;
* )
echo "Invalid input. Please enter y or n."
;;
esac
done
fi
if id llama-swap >/dev/null 2>&1; then
status "Removing llama-swap user..."
$SUDO userdel llama-swap
fi
if getent group llama-swap >/dev/null 2>&1; then
status "Removing llama-swap group..."
$SUDO groupdel llama-swap
fi