Compare commits
54 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2dc0ca0663 | |||
| a84098d3b4 | |||
| 4d02ccd26a | |||
| dfd47eeac4 | |||
| 1ac6499c08 | |||
| 25f3dc25e7 | |||
| 8422e4e6a1 | |||
| 02ee29d881 | |||
| b2a891f8f4 | |||
| 8d2b568897 | |||
| fb44cf4e08 | |||
| 02aee4e86d | |||
| f45896d395 | |||
| f7e46a359f | |||
| c260907415 | |||
| b83a5fa291 | |||
| 6e2ff28d59 | |||
| a8b81f2799 | |||
| f9ee7156dc | |||
| 2d00120781 | |||
| afc9aef058 | |||
| d7b390df74 | |||
| 5025c2f1f3 | |||
| e3a0b013c1 | |||
| f5763a94a0 | |||
| 8ada72eb57 | |||
| 2441b383d3 | |||
| 25f251699c | |||
| 7f37bcc6eb | |||
| 519c3a4d22 | |||
| 9dc4bcb46c | |||
| cb876c143b | |||
| bc652709a5 | |||
| 9548931258 | |||
| 5c5a5da664 | |||
| aa9ef59aa5 | |||
| 09e52c0500 | |||
| ca9063ffbe | |||
| 21d7973d11 | |||
| cc450e9c5f | |||
| 27465fe053 | |||
| 9667989727 | |||
| d9a1ddea0d | |||
| e7ab024ca0 | |||
| 448ccae959 | |||
| ec0348e431 | |||
| 06eda7f591 | |||
| 5fad24c16f | |||
| 8404244fab | |||
| 712cd01081 | |||
| 1f7aa359b1 | |||
| b138d6cf25 | |||
| fb7c808082 | |||
| a7e640b0f7 |
@@ -0,0 +1,15 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: true
|
||||
poem: false
|
||||
review_status: true
|
||||
collapse_walkthrough: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
@@ -0,0 +1,37 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Something is not working as expected...
|
||||
title: ''
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**Expected behaviour**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Operating system and version**
|
||||
|
||||
- OS: (linux, osx, windows, freebsd, etc)
|
||||
- GPUs: (list architecture)
|
||||
|
||||
**My Configuration**
|
||||
|
||||
```yaml
|
||||
# copy / paste your configuration here
|
||||
```
|
||||
|
||||
**Proxy Logs**
|
||||
|
||||
```
|
||||
# copy / paste from /logs
|
||||
```
|
||||
|
||||
**Upstream Logs**
|
||||
|
||||
```
|
||||
# copy/paste from /logs
|
||||
```
|
||||
@@ -13,11 +13,11 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-stale: 14
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -15,7 +15,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
#platform: [intel, cuda, vulkan, cpu, musa]
|
||||
platform: [cuda, vulkan, cpu, musa]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
name: Windows CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: make simple-responder-windows
|
||||
|
||||
- name: Save Simple Responder
|
||||
# nothing new to save ... skip this step
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
id: save-simple-responder
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
shell: bash
|
||||
run: make test-all
|
||||
@@ -1,6 +1,4 @@
|
||||
# This workflow will build a golang project
|
||||
|
||||
name: CI
|
||||
name: Linux CI
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -24,9 +22,26 @@ jobs:
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
run: make simple-responder
|
||||
|
||||
- name: Save Simple Responder
|
||||
# nothing new to save ... skip this step
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
id: save-simple-responder
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
@@ -20,10 +20,10 @@ clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
test:
|
||||
go test -short -v ./proxy
|
||||
go test -short -v -count=1 ./proxy
|
||||
|
||||
test-all:
|
||||
go test -v ./proxy
|
||||
go test -v -count=1 ./proxy
|
||||
|
||||
# Build OSX binary
|
||||
mac:
|
||||
@@ -46,6 +46,10 @@ simple-responder:
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||
|
||||
simple-responder-windows:
|
||||
@echo "Building simple responder for windows"
|
||||
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
|
||||
|
||||
# Ensure build directory exists
|
||||
$(BUILD_DIR):
|
||||
mkdir -p $(BUILD_DIR)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||
|
||||
## Features:
|
||||
|
||||
@@ -26,7 +26,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Docker and Podman support
|
||||
@@ -36,124 +36,49 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## config.yaml
|
||||
|
||||
llama-swap's configuration is purposefully simple.
|
||||
llama-swap's configuration is purposefully simple:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
cmd: |
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
--port ${PORT}
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
cmd: |
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
--port ${PORT}
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>But also very powerful ...</summary>
|
||||
.. but also supports many advanced features:
|
||||
|
||||
```yaml
|
||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||
# Default (and minimum) is 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
- `groups` to run multiple models at once
|
||||
- `macros` for reusable snippets
|
||||
- `ttl` to automatically unload models
|
||||
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||
- `env` variables to pass custom environment to inference servers
|
||||
- `useModelName` to override model names sent to upstream servers
|
||||
- `healthCheckTimeout` to control model startup wait times
|
||||
- `${PORT}` automatic port variables for dynamic port assignment
|
||||
- `cmdStop` for to gracefully stop Docker/Podman containers
|
||||
|
||||
# Valid log levels: debug, info (default), warn, error
|
||||
logLevel: info
|
||||
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
|
||||
# where to reach the server started by cmd, make sure the ports match
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases names to use this model for
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# check this path for an HTTP 200 OK before serving requests
|
||||
# default: /health to match llama.cpp
|
||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
||||
# until the model is ready
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# automatically unload the model after this many seconds
|
||||
# ttl values must be a value greater than 0
|
||||
# default: 0 = never unload model
|
||||
ttl: 60
|
||||
|
||||
# `useModelName` overrides the model name in the request
|
||||
# and sends a specific name to the upstream server
|
||||
useModelName: "qwen:qwq"
|
||||
|
||||
# unlisted models do not show up in /v1/models or /upstream lists
|
||||
# but they can still be requested as normal
|
||||
"qwen-unlisted":
|
||||
unlisted: true
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:9790"
|
||||
cmd: >
|
||||
docker run --name dockertest
|
||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# profiles eliminates swapping by running multiple models at the same time
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
```
|
||||
|
||||
### Use Case Examples
|
||||
|
||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
||||
|
||||
## Configuration
|
||||
|
||||
llama-s
|
||||
|
||||
</details>
|
||||
Check the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki for all options.
|
||||
|
||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Docker is the quickest way to try out llama-swap:
|
||||
|
||||
```
|
||||
# use CPU inference
|
||||
```shell
|
||||
# use CPU inference comes with the example config above
|
||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||
|
||||
|
||||
# qwen2.5 0.5B
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
@@ -161,7 +86,6 @@ $ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
|
||||
|
||||
# SmolLM2 135M
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
@@ -171,7 +95,7 @@ $ curl -s http://localhost:9292/v1/chat/completions \
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Docker images are nightly ...</summary>
|
||||
<summary>Docker images are built nightly for cuda, intel, vulcan, etc ...</summary>
|
||||
|
||||
They include:
|
||||
|
||||
@@ -185,7 +109,7 @@ Specific versions are also available and are tagged with the llama-swap, archite
|
||||
|
||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||
|
||||
```
|
||||
```shell
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
@@ -200,7 +124,12 @@ Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`.
|
||||
Available flags:
|
||||
- `--config`: Path to the configuration file (default: `config.yaml`).
|
||||
- `--listen`: Address and port to listen on (default: `:8080`).
|
||||
- `--version`: Show version information and exit.
|
||||
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
|
||||
|
||||
### Building from source
|
||||
|
||||
@@ -215,7 +144,7 @@ Open the `http://<host>/logs` with your browser to get a web interface with stre
|
||||
|
||||
Of course, CLI access is also supported:
|
||||
|
||||
```
|
||||
```shell
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
|
||||
@@ -241,36 +170,6 @@ Any OpenAI compatible server would work. llama-swap was originally designed for
|
||||
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
|
||||
## Systemd Unit Files
|
||||
|
||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||
|
||||
`/etc/systemd/system/llama-swap.service`
|
||||
|
||||
```
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
User=nobody
|
||||
|
||||
# set this to match your environment
|
||||
ExecStart=/path/to/llama-swap --config /path/to/llama-swap.config.yml
|
||||
|
||||
Restart=on-failure
|
||||
RestartSec=3
|
||||
StartLimitBurst=3
|
||||
StartLimitInterval=30
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
## Star History
|
||||
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
||||
</picture>
|
||||
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||
|
||||
+21
-25
@@ -3,15 +3,22 @@
|
||||
healthCheckTimeout: 90
|
||||
|
||||
# valid log levels: debug, info (default), warn, error
|
||||
logLevel: info
|
||||
logLevel: debug
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
groups:
|
||||
coding:
|
||||
swap: false
|
||||
members:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
cmd: |
|
||||
models/llama-server-osx
|
||||
--port 9001
|
||||
--port ${PORT}
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
proxy: http://127.0.0.1:9001
|
||||
|
||||
# list of model name aliases this llama.cpp instance can serve
|
||||
aliases:
|
||||
@@ -24,17 +31,15 @@ models:
|
||||
ttl: 5
|
||||
|
||||
"qwen":
|
||||
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9002
|
||||
cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
aliases:
|
||||
- gpt-3.5-turbo
|
||||
- gpt-3.5-turbo
|
||||
|
||||
# Embedding example with Nomic
|
||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
||||
"nomic":
|
||||
proxy: http://127.0.0.1:9005
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9005
|
||||
cmd: |
|
||||
models/llama-server-osx --port ${PORT}
|
||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
||||
--ctx-size 8192
|
||||
--batch-size 8192
|
||||
@@ -46,19 +51,17 @@ models:
|
||||
# Reranking example with bge-reranker
|
||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
||||
"bge-reranker":
|
||||
proxy: http://127.0.0.1:9006
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9006
|
||||
cmd: |
|
||||
models/llama-server-osx --port ${PORT}
|
||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
||||
--ctx-size 8192
|
||||
--reranking
|
||||
|
||||
# Docker Support (v26.1.4+ required!)
|
||||
"dockertest":
|
||||
proxy: "http://127.0.0.1:9790"
|
||||
cmd: >
|
||||
cmd: |
|
||||
docker run --name dockertest
|
||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggerganov/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
@@ -67,8 +70,7 @@ models:
|
||||
env:
|
||||
- CUDA_VISIBLE_DEVICES=0,1
|
||||
- env1=hello
|
||||
cmd: build/simple-responder --port 8999
|
||||
proxy: http://127.0.0.1:8999
|
||||
cmd: build/simple-responder --port ${PORT}
|
||||
unlisted: true
|
||||
|
||||
# use "none" to skip check. Caution this may cause some requests to fail
|
||||
@@ -83,10 +85,4 @@ models:
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
unlisted: true
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
unlisted: true
|
||||
@@ -0,0 +1,153 @@
|
||||
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
|
||||
|
||||
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
|
||||
|
||||
## Here's what you you need:
|
||||
|
||||
- aider - [installation docs](https://aider.chat/docs/install.html)
|
||||
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
|
||||
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
|
||||
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
|
||||
- 24GB VRAM video card
|
||||
|
||||
## Running aider
|
||||
|
||||
The goal is getting this command line to work:
|
||||
|
||||
```sh
|
||||
aider --architect \
|
||||
--no-show-model-warnings \
|
||||
--model openai/QwQ \
|
||||
--editor-model openai/qwen-coder-32B \
|
||||
--model-settings-file aider.model.settings.yml \
|
||||
--openai-api-key "sk-na" \
|
||||
--openai-api-base "http://10.0.1.24:8080/v1" \
|
||||
```
|
||||
|
||||
Set `--openai-api-base` to the IP and port where your llama-swap is running.
|
||||
|
||||
## Create an aider model settings file
|
||||
|
||||
```yaml
|
||||
# aider.model.settings.yml
|
||||
|
||||
#
|
||||
# !!! important: model names must match llama-swap configuration names !!!
|
||||
#
|
||||
|
||||
- name: "openai/QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/qwen-coder-32B"
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
- name: "openai/qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
```
|
||||
|
||||
## llama-swap configuration
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||
--ctx-size 16000
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
-ngl 99
|
||||
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
|
||||
"QwQ":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
--ctx-size 32000
|
||||
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
|
||||
--min-p 0.01 --top-k 40 --top-p 0.95
|
||||
-ngl 99
|
||||
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||
```
|
||||
|
||||
## Advanced, Dual GPU Configuration
|
||||
|
||||
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
|
||||
|
||||
In llama-swap's configuration file:
|
||||
|
||||
1. add a `profiles` section with `aider` as the profile name
|
||||
2. using the `env` field to specify the GPU IDs for each model
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
# Add a profile for aider
|
||||
profiles:
|
||||
aider:
|
||||
- qwen-coder-32B
|
||||
- QwQ
|
||||
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
# manually set the GPU to run on
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: /path/to/llama-server ...
|
||||
|
||||
"QwQ":
|
||||
# manually set the GPU to run on
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=1"
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: /path/to/llama-server ...
|
||||
```
|
||||
|
||||
Append the profile tag, `aider:`, to the model names in the model settings file
|
||||
|
||||
```yaml
|
||||
# aider.model.settings.yml
|
||||
- name: "openai/aider:QwQ"
|
||||
weak_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
|
||||
- name: "openai/aider:qwen-coder-32B"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
```
|
||||
|
||||
Run aider with:
|
||||
|
||||
```sh
|
||||
$ aider --architect \
|
||||
--no-show-model-warnings \
|
||||
--model openai/aider:QwQ \
|
||||
--editor-model openai/aider:qwen-coder-32B \
|
||||
--config aider.conf.yml \
|
||||
--model-settings-file aider.model.settings.yml
|
||||
--openai-api-key "sk-na" \
|
||||
--openai-api-base "http://10.0.1.24:8080/v1"
|
||||
```
|
||||
@@ -0,0 +1,28 @@
|
||||
# this makes use of llama-swap's profile feature to
|
||||
# keep the architect and editor models in VRAM on different GPUs
|
||||
|
||||
- name: "openai/aider:QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/aider:qwen-coder-32B"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||
|
||||
- name: "openai/aider:qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||
@@ -0,0 +1,26 @@
|
||||
- name: "openai/QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/qwen-coder-32B"
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
- name: "openai/qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
healthCheckTimeout: 300
|
||||
logLevel: debug
|
||||
|
||||
profiles:
|
||||
aider:
|
||||
- qwen-coder-32B
|
||||
- QwQ
|
||||
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
aliases:
|
||||
- coder
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
|
||||
# set appropriate paths for your environment
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||
--ctx-size 16000
|
||||
--ctx-size-draft 16000
|
||||
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
|
||||
-ngl 99 -ngld 99
|
||||
--draft-max 16 --draft-min 4 --draft-p-min 0.4
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
"QwQ":
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=1"
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
|
||||
# set appropriate paths for your environment
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn --metrics
|
||||
--slots
|
||||
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
--ctx-size 32000
|
||||
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||
--temp 0.6
|
||||
--repeat-penalty 1.1
|
||||
--dry-multiplier 0.5
|
||||
--min-p 0.01
|
||||
--top-k 40
|
||||
--top-p 0.95
|
||||
-ngl 99 -ngld 99
|
||||
@@ -3,8 +3,8 @@ module github.com/mostlygeek/llama-swap
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
@@ -12,6 +12,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
@@ -37,7 +38,7 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.37.0 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
@@ -9,12 +11,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
@@ -23,6 +29,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
@@ -74,32 +82,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
BIN
Binary file not shown.
|
After Width: | Height: | Size: 351 KiB |
+141
-11
@@ -1,25 +1,34 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
)
|
||||
|
||||
var version string = "0"
|
||||
var commit string = "abcd1234"
|
||||
var date = "unknown"
|
||||
var (
|
||||
version string = "0"
|
||||
commit string = "abcd1234"
|
||||
date string = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
@@ -34,6 +43,10 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(config.Profiles) > 0 {
|
||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
@@ -42,18 +55,135 @@ func main() {
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
|
||||
// Setup channels for server management
|
||||
reloadChan := make(chan *proxy.ProxyManager)
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Create server with initial handler
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
Handler: proxyManager,
|
||||
}
|
||||
|
||||
// Start server
|
||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||
go func() {
|
||||
<-sigChan
|
||||
fmt.Println("Shutting down llama-swap")
|
||||
proxyManager.Shutdown()
|
||||
os.Exit(0)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
fmt.Printf("Fatal server error: %v\n", err)
|
||||
close(exitChan)
|
||||
}
|
||||
}()
|
||||
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := proxyManager.Run(*listenStr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
os.Exit(1)
|
||||
// Handle config reloads and signals
|
||||
go func() {
|
||||
currentManager := proxyManager
|
||||
for {
|
||||
select {
|
||||
case newManager := <-reloadChan:
|
||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
||||
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
||||
// Now do a full shutdown to clear the process map
|
||||
currentManager.Shutdown()
|
||||
currentManager = newManager
|
||||
srv.Handler = newManager
|
||||
log.Println("Server handler updated with new config")
|
||||
case sig := <-sigChan:
|
||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentManager.Shutdown()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
fmt.Printf("Server shutdown error: %v\n", err)
|
||||
}
|
||||
close(exitChan)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start file watcher if requested
|
||||
if *watchConfig {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
|
||||
} else {
|
||||
go watchConfigFileWithReload(absConfigPath, reloadChan)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
|
||||
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
|
||||
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
err = watcher.Add(configPath)
|
||||
if err != nil {
|
||||
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Watching config file for changes: %s", configPath)
|
||||
|
||||
var debounceTimer *time.Timer
|
||||
debounceDuration := 2 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// We only care about writes to the specific config file
|
||||
if event.Name == configPath && event.Has(fsnotify.Write) {
|
||||
// Reset or start the debounce timer
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
debounceTimer = time.AfterFunc(debounceDuration, func() {
|
||||
log.Printf("Config file modified: %s, reloading...", event.Name)
|
||||
|
||||
// Try up to 3 times with exponential backoff
|
||||
var newConfig proxy.Config
|
||||
var err error
|
||||
for retries := 0; retries < 3; retries++ {
|
||||
// Load new configuration
|
||||
newConfig, err = proxy.LoadConfig(configPath)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
|
||||
if retries < 2 {
|
||||
time.Sleep(time.Duration(1<<retries) * time.Second)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Failed to load new config after retries: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create new ProxyManager with new config
|
||||
newPM := proxy.New(newConfig)
|
||||
reloadChan <- newPM
|
||||
log.Println("Config reloaded successfully")
|
||||
})
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
log.Println("File watcher error channel closed.")
|
||||
return
|
||||
}
|
||||
log.Printf("File watcher error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
**
|
||||
Test how exec.Cmd.CommandContext behaves under certain conditions:*
|
||||
|
||||
- process is killed externally, what happens with cmd.Wait() *
|
||||
✔︎ it returns. catches crashes.*
|
||||
|
||||
- process ignores SIGTERM*
|
||||
✔︎ `kill()` is called after cmd.WaitDelay*
|
||||
|
||||
- this process exits, what happens with children (kill -9 <this process' pid>)*
|
||||
x they stick around. have to be manually killed.*
|
||||
|
||||
- .WithTimeout()'s cancel is called *
|
||||
✔︎ process is killed after it ignores sigterm, cmd.Wait() catches it.*
|
||||
|
||||
- parent receives SIGINT/SIGTERM, what happens
|
||||
✔︎ waits for child process to exit, then exits gracefully.
|
||||
*/
|
||||
func main() {
|
||||
|
||||
// swap between these to use kill -9 <pid> on the cli to sim external crash
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
//ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
//cmd := exec.CommandContext(ctx, "sleep", "1")
|
||||
cmd := exec.CommandContext(ctx,
|
||||
"../../build/simple-responder_darwin_arm64",
|
||||
//"-ignore-sig-term", /* so it doesn't exit on receiving SIGTERM, test cmd.WaitTimeout */
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
// set a wait delay before signing sig kill
|
||||
cmd.WaitDelay = 500 * time.Millisecond
|
||||
cmd.Cancel = func() error {
|
||||
fmt.Println("✔︎ Cancel() called, sending SIGTERM")
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
//return nil
|
||||
|
||||
// this error is returned by cmd.Wait(), and can be used to
|
||||
// single an error when the process couldn't be normally terminated
|
||||
// but since a SIGTERM is sent, it's probably ok to return a nil
|
||||
// as WaitDelay timing out will override the any error set here.
|
||||
//
|
||||
// test by enabling/disabling -ignore-sig-term on the process
|
||||
// with -ignore-sig-term enabled, cmd.Wait() will have "signal: killed"
|
||||
// without it, it will show the "new error from cancel"
|
||||
return errors.New("error from cmd.Cancel()") // sets error returned by cmd.Wait()
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
fmt.Println("Error starting process:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// catch signals. Calls cancel() which will cause cmd.Wait() to return and
|
||||
// this program to eventually exit gracefully.
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
signal := <-sigChan
|
||||
fmt.Printf("✔︎ Received signal: %d, Killing process... with cancel before exiting\n", signal)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
fmt.Printf("✔︎ Parent Pid: %d, Process Pid: %d\n", os.Getpid(), cmd.Process.Pid)
|
||||
fmt.Println("✔︎ Process started, cmd.Wait() ... ")
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Println("✔︎ cmd.Wait returned, Error:", err)
|
||||
} else {
|
||||
fmt.Println("✔︎ cmd.Wait returned, Process exited on its own")
|
||||
}
|
||||
fmt.Println("✔︎ Child process exited, Done.")
|
||||
}
|
||||
@@ -26,6 +26,8 @@ func main() {
|
||||
|
||||
silent := flag.Bool("silent", false, "disable all logging")
|
||||
|
||||
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
// Create a new Gin router
|
||||
@@ -33,14 +35,17 @@ func main() {
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.String(200, *responseMessage)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
@@ -63,8 +68,11 @@ func main() {
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
// issue #41
|
||||
@@ -104,6 +112,10 @@ func main() {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||
"model": model,
|
||||
|
||||
// expose some header values for testing
|
||||
"h_content_type": c.GetHeader("Content-Type"),
|
||||
"h_content_length": c.GetHeader("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
@@ -180,6 +192,10 @@ func main() {
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
if !*silent {
|
||||
fmt.Printf("My PID: %d\n", os.Getpid())
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("simple-responder listening on %s\n", address)
|
||||
// service connections
|
||||
@@ -190,11 +206,36 @@ func main() {
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown the server with
|
||||
// a timeout of 5 seconds.
|
||||
quit := make(chan os.Signal, 1)
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
// kill (no param) default send syscall.SIGTERM
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
countSigInt := 0
|
||||
|
||||
runloop:
|
||||
for {
|
||||
signal := <-sigChan
|
||||
switch signal {
|
||||
case syscall.SIGINT:
|
||||
countSigInt++
|
||||
if countSigInt > 1 {
|
||||
break runloop
|
||||
} else {
|
||||
log.Println("Received SIGINT, send another SIGINT to shutdown")
|
||||
}
|
||||
case syscall.SIGTERM:
|
||||
if *ignoreSigTerm {
|
||||
log.Println("Ignoring SIGTERM")
|
||||
} else {
|
||||
log.Println("Received SIGTERM, shutting down")
|
||||
break runloop
|
||||
}
|
||||
default:
|
||||
break runloop
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("simple-responder shutting down")
|
||||
}
|
||||
|
||||
+244
-14
@@ -2,15 +2,23 @@ package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/shlex"
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
@@ -18,21 +26,56 @@ type ModelConfig struct {
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros map[string]string `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
@@ -53,42 +96,229 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
var config Config
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
if config.HealthCheckTimeout == 0 {
|
||||
// this high default timeout helps avoid failing health checks
|
||||
// for configurations that wait for docker or have slower startup
|
||||
config.HealthCheckTimeout = 120
|
||||
} else if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// set default port ranges
|
||||
if config.StartPort == 0 {
|
||||
// default to 5800
|
||||
config.StartPort = 5800
|
||||
} else if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
for macroName, macroValue := range config.Macros {
|
||||
if len(macroName) >= 64 {
|
||||
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
|
||||
}
|
||||
if !macroNameRegex.MatchString(macroName) {
|
||||
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
|
||||
}
|
||||
if len(macroValue) >= 1024 {
|
||||
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
|
||||
}
|
||||
switch macroName {
|
||||
case "PORT":
|
||||
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
||||
for macroName, macroValue := range config.Macros {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
|
||||
}
|
||||
|
||||
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||
if modelConfig.Proxy == "" {
|
||||
modelConfig.Proxy = "http://localhost:${PORT}"
|
||||
}
|
||||
|
||||
nextPortStr := strconv.Itoa(nextPort)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
|
||||
nextPort++
|
||||
} else if modelConfig.Proxy == "" {
|
||||
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if _, exists := config.Macros[macroName]; !exists {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName, _ := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
// Remove trailing backslashes
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
args, err := shlex.Split(cmdStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
# comment 1
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
+292
-27
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -18,6 +19,8 @@ func TestConfig_Load(t *testing.T) {
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
@@ -30,16 +33,37 @@ models:
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
@@ -52,7 +76,11 @@ profiles:
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
expected := &Config{
|
||||
expected := Config{
|
||||
StartPort: 5800,
|
||||
Macros: map[string]string{
|
||||
"svr-path": "path/to/server",
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -62,12 +90,24 @@ profiles:
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
@@ -77,6 +117,25 @@ profiles:
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -87,6 +146,63 @@ profiles:
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
|
||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
group2:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
`
|
||||
// Load the config and verify
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
|
||||
// a Contains as order of the map is not guaranteed
|
||||
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
||||
}
|
||||
|
||||
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- m1
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
checkEndpoint: "/"
|
||||
aliases:
|
||||
- m1
|
||||
- m2
|
||||
`
|
||||
// Load the config and verify
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
|
||||
// this is a contains because it could be `model1` or `model2` depending on the order
|
||||
// go decided on the order of the map
|
||||
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||
}
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
@@ -147,30 +263,179 @@ func TestConfig_FindConfig(t *testing.T) {
|
||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||
}
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
||||
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
--arg3 123 \
|
||||
--arg4 '"string in string"'
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
t.Run("Default Port Ranges", func(t *testing.T) {
|
||||
content := ``
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
})
|
||||
t.Run("User specific port ranges", func(t *testing.T) {
|
||||
content := `startPort: 1000`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1000, config.StartPort)
|
||||
})
|
||||
|
||||
t.Run("Invalid start port", func(t *testing.T) {
|
||||
content := `startPort: abcd`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("start port must be greater than 1", func(t *testing.T) {
|
||||
content := `startPort: -99`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("Automatic port assignments", func(t *testing.T) {
|
||||
content := `
|
||||
startPort: 5800
|
||||
models:
|
||||
model1:
|
||||
cmd: svr --port ${PORT}
|
||||
model2:
|
||||
cmd: svr --port ${PORT}
|
||||
proxy: "http://172.11.22.33:${PORT}"
|
||||
model3:
|
||||
cmd: svr --port 1999
|
||||
proxy: "http://1.2.3.4:1999"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
||||
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
||||
|
||||
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
||||
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
||||
|
||||
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
||||
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: svr --port 111
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_MacroReplacement(t *testing.T) {
|
||||
content := `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
argOne: "--arg1"
|
||||
argTwo: "--arg2"
|
||||
autoPort: "--port ${PORT}"
|
||||
|
||||
models:
|
||||
model1:
|
||||
cmd: |
|
||||
${svr-path} ${argTwo}
|
||||
# the automatic ${PORT} is replaced
|
||||
${autoPort}
|
||||
${argOne}
|
||||
--arg3 three
|
||||
cmdStop: |
|
||||
/path/to/stop.sh --port ${PORT} ${argTwo}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
|
||||
|
||||
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
|
||||
}
|
||||
|
||||
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
content string
|
||||
}{
|
||||
{
|
||||
name: "unknown macro in cmd",
|
||||
field: "cmd",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: |
|
||||
${svr-path} --port ${PORT}
|
||||
${unknownMacro}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "unknown macro in cmdStop",
|
||||
field: "cmdStop",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
cmdStop: "kill ${unknownMacro}"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "unknown macro in proxy",
|
||||
field: "proxy",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
proxy: "http://localhost:${unknownMacro}"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "unknown macro in checkEndpoint",
|
||||
field: "checkEndpoint",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
checkEndpoint: "http://localhost:${unknownMacro}/health"
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||
//t.Log(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// does not support single quoted strings like in config_posix_test.go
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
|
||||
-a "double quotes" \
|
||||
-s
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
+24
-3
@@ -14,6 +14,7 @@ import (
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
@@ -26,6 +27,17 @@ func TestMain(m *testing.M) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "debug":
|
||||
testLogger.SetLogLevel(LevelDebug)
|
||||
case "warn":
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
case "info":
|
||||
testLogger.SetLogLevel(LevelInfo)
|
||||
default:
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
@@ -33,17 +45,26 @@ func TestMain(m *testing.M) {
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
|
||||
if goos == "windows" {
|
||||
return filepath.Join("..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
func getTestPort() int {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
||||
return port
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
|
||||
@@ -170,6 +170,7 @@
|
||||
|
||||
this.eventSource.onmessage = (event) => {
|
||||
this.logData += event.data;
|
||||
this.logData = this.logData.slice(-1024 * 100);
|
||||
this.render();
|
||||
};
|
||||
|
||||
|
||||
+217
-95
@@ -8,6 +8,8 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -22,18 +24,28 @@ const (
|
||||
StateReady ProcessState = ProcessState("ready")
|
||||
StateStopping ProcessState = ProcessState("stopping")
|
||||
|
||||
// failed a health check on start and will not be recovered
|
||||
StateFailed ProcessState = ProcessState("failed")
|
||||
|
||||
// process is shutdown and will not be restarted
|
||||
StateShutdown ProcessState = ProcessState("shutdown")
|
||||
)
|
||||
|
||||
type StopStrategy int
|
||||
|
||||
const (
|
||||
StopImmediately StopStrategy = iota
|
||||
StopWaitForInflightRequest
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
|
||||
// PR #155 called to cancel the upstream process
|
||||
cancelUpstream context.CancelFunc
|
||||
|
||||
// closed when command exits
|
||||
cmdWaitChan chan struct{}
|
||||
|
||||
processLogger *LogMonitor
|
||||
proxyLogger *LogMonitor
|
||||
|
||||
@@ -50,24 +62,40 @@ type Process struct {
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
|
||||
// for managing shutdown state
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
// for managing concurrency limits
|
||||
concurrencyLimitSemaphore chan struct{}
|
||||
|
||||
// used for testing to override the default value
|
||||
gracefulStopTimeout time.Duration
|
||||
|
||||
// track the number of failed starts
|
||||
failedStartCount int
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
concurrentLimit := 10
|
||||
if config.ConcurrencyLimit > 0 {
|
||||
concurrentLimit = config.ConcurrencyLimit
|
||||
}
|
||||
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
cancelUpstream: nil,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
healthCheckTimeout: healthCheckTimeout,
|
||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||
state: StateStopped,
|
||||
shutdownCtx: ctx,
|
||||
shutdownCancel: cancel,
|
||||
|
||||
// concurrency limit
|
||||
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||
|
||||
// To be removed when migration over exec.CommandContext is complete
|
||||
// stop timeout
|
||||
gracefulStopTimeout: 10 * time.Second,
|
||||
cmdWaitChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,16 +117,17 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
||||
defer p.stateMutex.Unlock()
|
||||
|
||||
if p.state != expectedState {
|
||||
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
|
||||
return p.state, ErrExpectedStateMismatch
|
||||
}
|
||||
|
||||
if !isValidTransition(p.state, newState) {
|
||||
p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState)
|
||||
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
|
||||
return p.state, ErrInvalidStateTransition
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState)
|
||||
p.state = newState
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
return p.state, nil
|
||||
}
|
||||
|
||||
@@ -108,12 +137,12 @@ func isValidTransition(from, to ProcessState) bool {
|
||||
case StateStopped:
|
||||
return to == StateStarting
|
||||
case StateStarting:
|
||||
return to == StateReady || to == StateFailed || to == StateStopping
|
||||
return to == StateReady || to == StateStopping || to == StateStopped
|
||||
case StateReady:
|
||||
return to == StateStopping
|
||||
case StateStopping:
|
||||
return to == StateStopped || to == StateShutdown
|
||||
case StateFailed, StateShutdown:
|
||||
case StateShutdown:
|
||||
return false // No transitions allowed from these states
|
||||
}
|
||||
return false
|
||||
@@ -160,17 +189,24 @@ func (p *Process) start() error {
|
||||
|
||||
p.waitStarting.Add(1)
|
||||
defer p.waitStarting.Done()
|
||||
|
||||
p.cmd = exec.Command(args[0], args[1:]...)
|
||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.processLogger
|
||||
p.cmd.Stderr = p.processLogger
|
||||
p.cmd.Env = p.config.Env
|
||||
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
p.cmdWaitChan = make(chan struct{})
|
||||
|
||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||
err = p.cmd.Start()
|
||||
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.state = StateStopped // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
err, curState, swapErr,
|
||||
@@ -179,6 +215,9 @@ func (p *Process) start() error {
|
||||
return fmt.Errorf("start() failed: %v", err)
|
||||
}
|
||||
|
||||
// Capture the exit error for later signalling
|
||||
go p.waitForCmd()
|
||||
|
||||
// One of three things can happen at this stage:
|
||||
// 1. The command exits unexpectedly
|
||||
// 2. The health check fails
|
||||
@@ -204,40 +243,32 @@ func (p *Process) start() error {
|
||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
||||
context.Background(),
|
||||
checkStartTime.Add(maxDuration),
|
||||
)
|
||||
defer cancelHealthCheck()
|
||||
|
||||
loop:
|
||||
// Ready Check loop
|
||||
for {
|
||||
select {
|
||||
case <-checkDeadline.Done():
|
||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
||||
} else {
|
||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||
currentState := p.CurrentState()
|
||||
if currentState != StateStarting {
|
||||
if currentState == StateStopped {
|
||||
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||
}
|
||||
case <-p.shutdownCtx.Done():
|
||||
return errors.New("health check interrupted due to shutdown")
|
||||
default:
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
p.proxyLogger.Infof("Health check passed on %s", healthURL)
|
||||
cancelHealthCheck()
|
||||
break loop
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
endTime, _ := checkDeadline.Deadline()
|
||||
ttl := time.Until(endTime)
|
||||
p.proxyLogger.Infof("Connection refused on %s, retrying in %.0fs", healthURL, ttl.Seconds())
|
||||
} else {
|
||||
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if time.Since(checkStartTime) > maxDuration {
|
||||
p.stopCommand()
|
||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||
}
|
||||
|
||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||
break
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
ttl := time.Until(checkStartTime.Add(maxDuration))
|
||||
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||
}
|
||||
}
|
||||
<-time.After(p.healthCheckLoopInterval)
|
||||
}
|
||||
}
|
||||
@@ -257,8 +288,7 @@ func (p *Process) start() error {
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
|
||||
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
@@ -269,82 +299,71 @@ func (p *Process) start() error {
|
||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||
} else {
|
||||
p.failedStartCount = 0
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Stop will wait for inflight requests to complete before stopping the process.
|
||||
func (p *Process) Stop() {
|
||||
// wait for any inflight requests before proceeding
|
||||
p.inFlightRequests.Wait()
|
||||
|
||||
// calling Stop() when state is invalid is a no-op
|
||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
// stop the process with a graceful exit timeout
|
||||
p.stopCommand(5 * time.Second)
|
||||
// wait for any inflight requests before proceeding
|
||||
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||
p.inFlightRequests.Wait()
|
||||
p.StopImmediately()
|
||||
}
|
||||
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
|
||||
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||
func (p *Process) StopImmediately() {
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
|
||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
}
|
||||
|
||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||
// of time for any inflight requests to complete before shutting down. If the Process
|
||||
// is in the state of starting, it will cancel it and shut it down
|
||||
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||
// the StateShutdown state, it can not be started again.
|
||||
func (p *Process) Shutdown() {
|
||||
p.shutdownCancel()
|
||||
p.stopCommand(5 * time.Second)
|
||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||
return
|
||||
}
|
||||
|
||||
p.stopCommand()
|
||||
// just force it to this state since there is no recovery from shutdown
|
||||
p.state = StateShutdown
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
||||
defer cancelTimeout()
|
||||
|
||||
sigtermNormal := make(chan error, 1)
|
||||
go func() {
|
||||
sigtermNormal <- p.cmd.Wait()
|
||||
func (p *Process) stopCommand() {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
}()
|
||||
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
|
||||
if p.cancelUpstream == nil {
|
||||
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.terminateProcess(); err != nil {
|
||||
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sigtermTimeout.Done():
|
||||
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
|
||||
p.cmd.Process.Kill()
|
||||
case err := <-sigtermNormal:
|
||||
if err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok {
|
||||
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
|
||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
} else {
|
||||
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
p.cancelUpstream()
|
||||
<-p.cmdWaitChan
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}
|
||||
@@ -369,14 +388,24 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
requestBeginTime := time.Now()
|
||||
var startDuration time.Duration
|
||||
|
||||
// prevent new requests from being made while stopping or irrecoverable
|
||||
currentState := p.CurrentState()
|
||||
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
||||
if currentState == StateShutdown || currentState == StateStopping {
|
||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case p.concurrencyLimitSemaphore <- struct{}{}:
|
||||
defer func() { <-p.concurrencyLimitSemaphore }()
|
||||
default:
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
@@ -385,11 +414,13 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
beginStartTime := time.Now()
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
startDuration = time.Since(beginStartTime)
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
@@ -400,6 +431,12 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
|
||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
||||
if err == nil {
|
||||
req.ContentLength = contentLength
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
@@ -433,4 +470,89 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||
p.ID, r.RequestURI, startDuration, totalTime)
|
||||
}
|
||||
|
||||
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
|
||||
func (p *Process) waitForCmd() {
|
||||
exitErr := p.cmd.Wait()
|
||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||
|
||||
if exitErr != nil {
|
||||
if errno, ok := exitErr.(syscall.Errno); ok {
|
||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
|
||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||
}
|
||||
} else {
|
||||
if exitErr.Error() != "context canceled" /* this is normal */ {
|
||||
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentState := p.CurrentState()
|
||||
switch currentState {
|
||||
case StateStopping:
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||
p.state = StateStopped
|
||||
}
|
||||
default:
|
||||
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||
p.state = StateStopped // force it to be in this state
|
||||
}
|
||||
close(p.cmdWaitChan)
|
||||
}
|
||||
|
||||
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||
func (p *Process) cmdStopUpstreamProcess() error {
|
||||
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
|
||||
|
||||
// this should never happen ...
|
||||
if p.cmd == nil || p.cmd.Process == nil {
|
||||
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" {
|
||||
p.config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
// replace ${PID} with the pid of the process
|
||||
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
if err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
stopCmd.Env = p.config.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import "syscall"
|
||||
|
||||
func (p *Process) terminateProcess() error {
|
||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func (p *Process) terminateProcess() error {
|
||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
||||
return cmd.Run()
|
||||
}
|
||||
+170
-17
@@ -2,9 +2,10 @@ package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,16 +14,25 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
discardLogger = NewLogMonitorWriter(io.Discard)
|
||||
debugLogger = NewLogMonitorWriter(os.Stdout)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// flip to help with debugging tests
|
||||
if false {
|
||||
debugLogger.SetLogLevel(LevelDebug)
|
||||
} else {
|
||||
debugLogger.SetLogLevel(LevelError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -58,7 +68,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
@@ -86,7 +96,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, discardLogger, discardLogger)
|
||||
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -96,8 +106,8 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "start() failed: ")
|
||||
}
|
||||
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
@@ -111,7 +121,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger)
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -153,7 +163,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, discardLogger, discardLogger)
|
||||
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
@@ -180,7 +190,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, discardLogger, discardLogger)
|
||||
process := NewProcess("t", 10, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
@@ -238,18 +248,14 @@ func TestProcess_SwapState(t *testing.T) {
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
||||
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||
@@ -257,7 +263,7 @@ func TestProcess_SwapState(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger)
|
||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
|
||||
p.state = test.currentState
|
||||
|
||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||
@@ -290,7 +296,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger)
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
|
||||
|
||||
// make it a lot faster
|
||||
process.healthCheckLoopInterval = time.Second
|
||||
@@ -311,3 +317,150 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping Exit Interrupts Health Check test")
|
||||
}
|
||||
|
||||
// should run and exit but interrupt the long checkHealthTimeout
|
||||
checkHealthTimeout := 5
|
||||
config := ModelConfig{
|
||||
Cmd: "sleep 1",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||
process.healthCheckLoopInterval = time.Second // make it faster
|
||||
err := process.start()
|
||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long concurrency limit test")
|
||||
}
|
||||
|
||||
expectedMessage := "concurrency_limit_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// only allow 1 concurrent request at a time
|
||||
config.ConcurrencyLimit = 1
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||
defer process.Stop()
|
||||
|
||||
// launch a goroutine first to take up the semaphore
|
||||
go func() {
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req1)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}()
|
||||
|
||||
// let the goroutine start
|
||||
<-time.After(time.Millisecond * 25)
|
||||
|
||||
denied := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, denied)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
|
||||
func TestProcess_StopImmediately(t *testing.T) {
|
||||
expectedMessage := "test_stop_immediate"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
}()
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
|
||||
expectedMessage := "test_sigkill"
|
||||
binaryPath := getSimpleResponderPath()
|
||||
port := getTestPort()
|
||||
|
||||
config := ModelConfig{
|
||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||
// to force the process to exit
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// reduce to make testing go faster
|
||||
process.gracefulStopTimeout = time.Second
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
|
||||
waitChan := make(chan struct{})
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// StatusOK because that was already sent before the kill
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||
// then the unexpected EOF is sent after the kill
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||
} else {
|
||||
assert.Contains(t, w.Body.String(), "unexpected EOF")
|
||||
}
|
||||
|
||||
close(waitChan)
|
||||
}()
|
||||
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
|
||||
// the request should have been interrupted by SIGKILL
|
||||
<-waitChan
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
config := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
config.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
persistent bool
|
||||
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
|
||||
// map of current processes
|
||||
processes map[string]*Process
|
||||
lastUsedProcess string
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
}
|
||||
|
||||
pg := &ProcessGroup{
|
||||
id: id,
|
||||
config: config,
|
||||
swap: groupConfig.Swap,
|
||||
exclusive: groupConfig.Exclusive,
|
||||
persistent: groupConfig.Persistent,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
processes: make(map[string]*Process),
|
||||
}
|
||||
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
return pg
|
||||
}
|
||||
|
||||
// ProxyRequest proxies a request to the specified model
|
||||
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||
if !pg.HasMember(modelID) {
|
||||
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||
}
|
||||
|
||||
if pg.swap {
|
||||
pg.Lock()
|
||||
if pg.lastUsedProcess != modelID {
|
||||
if pg.lastUsedProcess != "" {
|
||||
pg.processes[pg.lastUsedProcess].Stop()
|
||||
}
|
||||
pg.lastUsedProcess = modelID
|
||||
}
|
||||
pg.Unlock()
|
||||
}
|
||||
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
if len(pg.processes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: false,
|
||||
Exclusive: true,
|
||||
Members: []string{"model3", "model4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_HasMember(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model1"))
|
||||
assert.True(t, pg.HasMember("model2"))
|
||||
assert.False(t, pg.HasMember("model3"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model1", "model2"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
// make sure only one process is in the running state
|
||||
count := 0
|
||||
for _, process := range pg.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model3", "model4"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure all the processes are running
|
||||
for _, process := range pg.processes {
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
}
|
||||
+156
-171
@@ -26,17 +26,18 @@ const (
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config *Config
|
||||
currentProcesses map[string]*Process
|
||||
ginEngine *gin.Engine
|
||||
config Config
|
||||
ginEngine *gin.Engine
|
||||
|
||||
// logging
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
muxLogger *LogMonitor
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
func New(config Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
@@ -49,26 +50,43 @@ func New(config *Config) *ProxyManager {
|
||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
||||
case "debug":
|
||||
proxyLogger.SetLogLevel(LevelDebug)
|
||||
upstreamLogger.SetLogLevel(LevelDebug)
|
||||
case "info":
|
||||
proxyLogger.SetLogLevel(LevelInfo)
|
||||
upstreamLogger.SetLogLevel(LevelInfo)
|
||||
case "warn":
|
||||
proxyLogger.SetLogLevel(LevelWarn)
|
||||
upstreamLogger.SetLogLevel(LevelWarn)
|
||||
case "error":
|
||||
proxyLogger.SetLogLevel(LevelError)
|
||||
upstreamLogger.SetLogLevel(LevelError)
|
||||
default:
|
||||
proxyLogger.SetLogLevel(LevelInfo)
|
||||
upstreamLogger.SetLogLevel(LevelInfo)
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
currentProcesses: make(map[string]*Process),
|
||||
ginEngine: gin.New(),
|
||||
config: config,
|
||||
ginEngine: gin.New(),
|
||||
|
||||
proxyLogger: proxyLogger,
|
||||
muxLogger: stdoutLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
for groupID := range config.Groups {
|
||||
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
||||
pm.processGroups[groupID] = processGroup
|
||||
}
|
||||
|
||||
pm.setupGinEngine()
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
@@ -179,63 +197,77 @@ func New(config *Config) *ProxyManager {
|
||||
|
||||
// Disable console color for testing
|
||||
gin.DisableConsoleColor()
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Run(addr ...string) error {
|
||||
return pm.ginEngine.Run(addr...)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
||||
// ServeHTTP implements http.Handler interface
|
||||
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
pm.ginEngine.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) StopProcesses() {
|
||||
// StopProcesses acquires a lock and stops all running upstream processes.
|
||||
// This is the public method safe for concurrent calls.
|
||||
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
||||
// a complete shutdown, allowing for process replacement without full termination.
|
||||
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
pm.stopProcesses()
|
||||
}
|
||||
|
||||
// for internal usage
|
||||
func (pm *ProxyManager) stopProcesses() {
|
||||
if len(pm.currentProcesses) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
for _, processGroup := range pm.processGroups {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
go func(processGroup *ProcessGroup) {
|
||||
defer wg.Done()
|
||||
process.Stop()
|
||||
}(process)
|
||||
processGroup.StopProcesses(strategy)
|
||||
}(processGroup)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
pm.currentProcesses = make(map[string]*Process)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Shutdown is called to shutdown all upstream processes
|
||||
// when llama-swap is shutting down.
|
||||
// Shutdown stops all processes managed by this ProxyManager
|
||||
func (pm *ProxyManager) Shutdown() {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// shutdown process in parallel
|
||||
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pm.currentProcesses {
|
||||
// Send shutdown signal to all process in groups
|
||||
for _, processGroup := range pm.processGroups {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
go func(processGroup *ProcessGroup) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
processGroup.Shutdown()
|
||||
}(processGroup)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||
}
|
||||
|
||||
if processGroup.exclusive {
|
||||
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
||||
for groupId, otherGroup := range pm.processGroups {
|
||||
if groupId != processGroup.id && !otherGroup.persistent {
|
||||
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return processGroup, realModelName, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := []interface{}{}
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
@@ -259,85 +291,12 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
|
||||
if profileName != "" {
|
||||
if _, found := pm.config.Profiles[profileName]; !found {
|
||||
return nil, fmt.Errorf("model group not found %s", profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(modelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
// check if model is part of the profile
|
||||
if profileName != "" {
|
||||
found := false
|
||||
for _, item := range pm.config.Profiles[profileName] {
|
||||
if item == realModelName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
|
||||
}
|
||||
}
|
||||
|
||||
// exit early when already running, otherwise stop everything and swap
|
||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||
|
||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// stop all running models
|
||||
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
|
||||
pm.stopProcesses()
|
||||
if profileName == "" {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
} else {
|
||||
for _, modelName := range pm.config.Profiles[profileName] {
|
||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||
}
|
||||
|
||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
||||
processKey := ProcessKeyName(profileName, modelID)
|
||||
pm.currentProcesses[processKey] = process
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// requestedProcessKey should exist due to swap
|
||||
return pm.currentProcesses[requestedProcessKey], nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
requestedModel := c.Param("model_id")
|
||||
|
||||
@@ -346,19 +305,21 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
} else {
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
||||
var html strings.Builder
|
||||
|
||||
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
|
||||
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><a href=\"/unload\">Unload all models</a><ul>")
|
||||
|
||||
// Extract keys and sort them
|
||||
var modelIDs []string
|
||||
@@ -373,7 +334,31 @@ func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
||||
|
||||
// Iterate over sorted keys
|
||||
for _, modelID := range modelIDs {
|
||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
|
||||
// Get process state
|
||||
processGroup := pm.findGroupByModelName(modelID)
|
||||
var state string
|
||||
if processGroup != nil {
|
||||
process := processGroup.processes[modelID]
|
||||
if process != nil {
|
||||
var stateStr string
|
||||
switch process.CurrentState() {
|
||||
case StateReady:
|
||||
stateStr = "Ready"
|
||||
case StateStarting:
|
||||
stateStr = "Starting"
|
||||
case StateStopping:
|
||||
stateStr = "Stopping"
|
||||
case StateShutdown:
|
||||
stateStr = "Shutdown"
|
||||
case StateStopped:
|
||||
stateStr = "Stopped"
|
||||
default:
|
||||
stateStr = "Unknown"
|
||||
}
|
||||
state = stateStr
|
||||
}
|
||||
}
|
||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a> - %s</li>", modelID, modelID, state))
|
||||
}
|
||||
html.WriteString("</ul></body></html>")
|
||||
c.Header("Content-Type", "text/html")
|
||||
@@ -390,49 +375,40 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
if process.config.UseModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
if profileName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
c.Request.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Parse multipart form
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||
@@ -446,15 +422,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Swap to the requested model
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get profile name and model name from the requested model
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||
|
||||
// Copy all form values
|
||||
for key, values := range c.Request.MultipartForm.Value {
|
||||
@@ -462,10 +439,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
fieldValue := value
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
if process.config.UseModelName != "" {
|
||||
fieldValue = process.config.UseModelName
|
||||
} else if profileName != "" {
|
||||
fieldValue = modelName
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
} else {
|
||||
fieldValue = requestedModel
|
||||
}
|
||||
}
|
||||
field, err := multipartWriter.CreateFormField(key)
|
||||
@@ -526,8 +506,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
modifiedReq.Header = c.Request.Header.Clone()
|
||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||
|
||||
// set the content length of the body
|
||||
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
process.ProxyRequest(c.Writer, modifiedReq)
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||
@@ -541,7 +529,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses()
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
|
||||
@@ -549,14 +537,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
context.Header("Content-Type", "application/json")
|
||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||
|
||||
for _, process := range pm.currentProcesses {
|
||||
|
||||
// Append the process ID and State (multiple entries if profiles are being used).
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
})
|
||||
|
||||
for _, processGroup := range pm.processGroups {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Put the results under the `running` key.
|
||||
@@ -567,15 +556,11 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
context.JSON(http.StatusOK, response) // Always return 200 OK
|
||||
}
|
||||
|
||||
func ProcessKeyName(groupName, modelName string) string {
|
||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||
}
|
||||
|
||||
func splitRequestedModel(requestedModel string) (string, string) {
|
||||
profileName, modelName := "", requestedModel
|
||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||
profileName = requestedModel[:idx]
|
||||
modelName = requestedModel[idx+1:]
|
||||
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
||||
for _, group := range pm.processGroups {
|
||||
if group.HasMember(modelName) {
|
||||
return group
|
||||
}
|
||||
}
|
||||
return profileName, modelName
|
||||
return nil
|
||||
}
|
||||
|
||||
+243
-328
@@ -8,85 +8,120 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
}
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
_, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
|
||||
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||
|
||||
}
|
||||
|
||||
// make sure there's only one loaded model
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
}
|
||||
|
||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileModel1 := ProcessKeyName("test", model1)
|
||||
profileModel2 := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1, model2},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model1"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
for modelID, requestedModel := range map[string]string{
|
||||
"model1": profileModel1,
|
||||
"model2": profileModel2,
|
||||
} {
|
||||
tests := []string{"model1", "model2"}
|
||||
for _, requestedModel := range tests {
|
||||
t.Run(requestedModel, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), requestedModel)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure there's two loaded models
|
||||
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
||||
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// Test that a persistent group is not affected by the swapping behaviour of
|
||||
// other groups.
|
||||
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
// the forever group is persistent and should not be affected by model1
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// make requests to load all models, loading model1 should not affect model2
|
||||
tests := []string{"model2", "model1"}
|
||||
for _, requestedModel := range tests {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelID)
|
||||
assert.Contains(t, w.Body.String(), requestedModel)
|
||||
}
|
||||
|
||||
// make sure there's two loaded models
|
||||
assert.Len(t, proxy.currentProcesses, 2)
|
||||
_, exists := proxy.currentProcesses[profileModel1]
|
||||
assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
|
||||
|
||||
_, exists = proxy.currentProcesses[profileModel2]
|
||||
assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
|
||||
assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady)
|
||||
assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// When a request for a different model comes in ProxyManager should wait until
|
||||
@@ -96,17 +131,18 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
}
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
results := map[string]string{}
|
||||
|
||||
@@ -122,15 +158,16 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
|
||||
results[key] = w.Body.String()
|
||||
var response map[string]string
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
results[key] = response["responseMessage"]
|
||||
mu.Unlock()
|
||||
}(key)
|
||||
|
||||
@@ -146,13 +183,14 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
config := Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
@@ -163,7 +201,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the listModelsHandler
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
// Check the response status code
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -213,50 +251,6 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
func TestProxyManager_ProfileNonMember(t *testing.T) {
|
||||
|
||||
model1 := "path1/model1"
|
||||
model2 := "path2/model2"
|
||||
|
||||
profileMemberName := ProcessKeyName("test", model1)
|
||||
profileNonMemberName := ProcessKeyName("test", model2)
|
||||
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
model1: getTestSimpleResponderConfig("model1"),
|
||||
model2: getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {model1},
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
// actual member of profile
|
||||
{
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "model1")
|
||||
}
|
||||
|
||||
// actual model, but non-member will 404
|
||||
{
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
@@ -268,23 +262,27 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||
model3Config.Proxy = "http://localhost:10003/"
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2", "model3"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": model3Config,
|
||||
},
|
||||
}
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
"test": {
|
||||
Swap: false,
|
||||
Members: []string{"model1", "model2", "model3"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Start all the processes
|
||||
var wg sync.WaitGroup
|
||||
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
|
||||
for _, modelName := range []string{"model1", "model2", "model3"} {
|
||||
wg.Add(1)
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
@@ -292,11 +290,10 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// send a request to trigger the proxy to load
|
||||
proxy.HandlerFunc(w, req)
|
||||
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||
//fmt.Println(w.Code, w.Body.String())
|
||||
}(modelName)
|
||||
}
|
||||
|
||||
@@ -308,64 +305,43 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Unload(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
}
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
proc, err := proxy.swapModel("model1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proc)
|
||||
|
||||
assert.Len(t, proxy.currentProcesses, 1)
|
||||
req := httptest.NewRequest("GET", "/unload", nil)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||
req = httptest.NewRequest("GET", "/unload", nil)
|
||||
w = httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
assert.Len(t, proxy.currentProcesses, 0)
|
||||
}
|
||||
|
||||
// issue 62, strip profile slug from model name
|
||||
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
||||
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "ok")
|
||||
// give it a bit of time to stop
|
||||
<-time.After(time.Millisecond * 250)
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Shared configuration
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
}
|
||||
LogLevel: "warn",
|
||||
})
|
||||
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
@@ -377,12 +353,12 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Create proxy once for all tests
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("no models loaded", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
@@ -400,13 +376,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Simulate browser call for the `/running` endpoint.
|
||||
req = httptest.NewRequest("GET", "/running", nil)
|
||||
w = httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
@@ -420,235 +396,132 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
})
|
||||
|
||||
t.Run("multiple models via profile", func(t *testing.T) {
|
||||
// Load more than one model.
|
||||
for _, model := range []string{"model1", "model2"} {
|
||||
profileModel := ProcessKeyName("test", model)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Simulate the browser call.
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
|
||||
// The JSON response must be valid.
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// The response should contain 2 models.
|
||||
assert.Len(t, response.Running, 2)
|
||||
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
}
|
||||
|
||||
// Iterate through the models and check their states as well.
|
||||
for _, entry := range response.Running {
|
||||
_, exists := expectedModels[entry.Model]
|
||||
assert.True(t, exists, "unexpected model %s", entry.Model)
|
||||
assert.Equal(t, "ready", entry.State)
|
||||
delete(expectedModels, entry.Model)
|
||||
}
|
||||
|
||||
// Since we deleted each model while testing for its validity we should have no more models in the response.
|
||||
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"TheExpectedModel"},
|
||||
},
|
||||
Models: map[string]ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
}
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
modelInput string
|
||||
expectModel string
|
||||
}{
|
||||
{
|
||||
name: "With Profile Prefix",
|
||||
modelInput: "test:TheExpectedModel",
|
||||
expectModel: "TheExpectedModel", // Profile prefix should be stripped
|
||||
},
|
||||
{
|
||||
name: "Without Profile Prefix",
|
||||
modelInput: "TheExpectedModel",
|
||||
expectModel: "TheExpectedModel", // Should remain the same
|
||||
},
|
||||
}
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("TheExpectedModel"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(tc.modelInput))
|
||||
assert.NoError(t, err)
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
// Generate random content length between 10 and 20
|
||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||
content := make([]byte, contentLength)
|
||||
_, err = fw.Write(content)
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
// Generate random content length between 10 and 20
|
||||
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||
content := make([]byte, contentLength)
|
||||
_, err = fw.Write(content)
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectModel, response["model"])
|
||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_SplitRequestedModel(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
expectedProfile string
|
||||
expectedModel string
|
||||
}{
|
||||
{"no profile", "gpt-4", "", "gpt-4"},
|
||||
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
|
||||
{"only profile", "profile1:", "profile1", ""},
|
||||
{"empty model", ":gpt-4", "", "gpt-4"},
|
||||
{"empty profile", ":", "", ""},
|
||||
{"no split char", "gpt-4", "", "gpt-4"},
|
||||
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
profileName, modelName := splitRequestedModel(tt.requestedModel)
|
||||
if profileName != tt.expectedProfile {
|
||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||
}
|
||||
if modelName != tt.expectedModel {
|
||||
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "TheExpectedModel", response["model"])
|
||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
|
||||
}
|
||||
|
||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||
func TestProxyManager_UseModelName(t *testing.T) {
|
||||
|
||||
upstreamModelName := "upstreamModel"
|
||||
|
||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||
modelConfig.UseModelName = upstreamModelName
|
||||
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1"},
|
||||
},
|
||||
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
}
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
requestedModel string
|
||||
}{
|
||||
{"useModelName over rides requested model", "model1"},
|
||||
{"useModelName over rides requested profile:model", "test:model1"},
|
||||
}
|
||||
requestedModel := "model1"
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.HandlerFunc(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||
|
||||
})
|
||||
}
|
||||
// make sure the content length was set correctly
|
||||
// simple-responder will return the content length it got in the response
|
||||
body := w.Body.Bytes()
|
||||
contentLength := int(gjson.GetBytes(body, "h_content_length").Int())
|
||||
assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength)
|
||||
})
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) {
|
||||
// Create a buffer with multipart form data
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(tt.requestedModel))
|
||||
assert.NoError(t, err)
|
||||
// Add the model field
|
||||
fw, err := w.CreateFormField("model")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte(requestedModel))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
// Add a file field
|
||||
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||
assert.NoError(t, err)
|
||||
_, err = fw.Write([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
w.Close()
|
||||
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.HandlerFunc(rec, req)
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, upstreamModelName, response["model"])
|
||||
})
|
||||
}
|
||||
// Verify the response
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, upstreamModelName, response["model"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogRequests: true,
|
||||
}
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -690,7 +563,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses()
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
||||
for k, v := range tt.requestHeaders {
|
||||
@@ -698,7 +571,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ginEngine.ServeHTTP(w, req)
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
|
||||
@@ -708,3 +581,45 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_Upstream(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]string
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
assert.Equal(t, "81", response["h_content_length"])
|
||||
assert.Equal(t, "model1", response["responseMessage"])
|
||||
}
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
#!/bin/sh
|
||||
# This script installs llama-swap on Linux.
|
||||
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
|
||||
|
||||
set -eu
|
||||
|
||||
LLAMA_SWAP_DEFAULT_ADDRESS=${LLAMA_SWAP_DEFAULT_ADDRESS:-"127.0.0.1:8080"}
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||
|
||||
status() { echo ">>> $*" >&2; }
|
||||
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||
|
||||
available() { command -v "$1" >/dev/null; }
|
||||
require() {
|
||||
_MISSING=''
|
||||
for TOOL in "$@"; do
|
||||
if ! available "$TOOL"; then
|
||||
_MISSING="$_MISSING $TOOL"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "$_MISSING"
|
||||
}
|
||||
|
||||
SUDO=
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
if ! available sudo; then
|
||||
error "This script requires superuser permissions. Please re-run as root."
|
||||
fi
|
||||
|
||||
SUDO="sudo"
|
||||
fi
|
||||
|
||||
NEEDS=$(require tee tar python3 mktemp)
|
||||
if [ -n "$NEEDS" ]; then
|
||||
status "ERROR: The following tools are required but missing:"
|
||||
for NEED in $NEEDS; do
|
||||
echo " - $NEED"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||
|
||||
ARCH=$(uname -m)
|
||||
case "$ARCH" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
aarch64|arm64) ARCH="arm64" ;;
|
||||
*) error "Unsupported architecture: $ARCH" ;;
|
||||
esac
|
||||
|
||||
IS_WSL2=false
|
||||
|
||||
KERN=$(uname -r)
|
||||
case "$KERN" in
|
||||
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
|
||||
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
|
||||
*) ;;
|
||||
esac
|
||||
|
||||
download_binary() {
|
||||
ASSET_NAME="linux_$ARCH"
|
||||
|
||||
TMPDIR=$(mktemp -d)
|
||||
trap 'rm -rf "${TMPDIR}"' EXIT INT TERM HUP
|
||||
PYTHON_SCRIPT=$(cat <<EOF
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import urllib.request
|
||||
|
||||
ASSET_NAME = "${ASSET_NAME}"
|
||||
|
||||
with urllib.request.urlopen("https://api.github.com/repos/mostlygeek/llama-swap/releases/latest") as resp:
|
||||
data = json.load(resp)
|
||||
for asset in data.get("assets", []):
|
||||
if ASSET_NAME in asset.get("name", ""):
|
||||
url = asset["browser_download_url"]
|
||||
break
|
||||
else:
|
||||
print("ERROR: Matching asset not found.", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
print("Downloading:", url, file=sys.stderr)
|
||||
output_path = os.path.join("${TMPDIR}", "llama-swap.tar.gz")
|
||||
urllib.request.urlretrieve(url, output_path)
|
||||
print(output_path)
|
||||
EOF
|
||||
)
|
||||
|
||||
TARFILE=$(python3 -c "$PYTHON_SCRIPT")
|
||||
if [ ! -f "$TARFILE" ]; then
|
||||
error "Failed to download binary."
|
||||
fi
|
||||
|
||||
status "Extracting to /usr/local/bin"
|
||||
$SUDO tar -xzf "$TARFILE" -C /usr/local/bin llama-swap
|
||||
}
|
||||
download_binary
|
||||
|
||||
configure_systemd() {
|
||||
if ! id llama-swap >/dev/null 2>&1; then
|
||||
status "Creating llama-swap user..."
|
||||
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
|
||||
fi
|
||||
if getent group render >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to render group..."
|
||||
$SUDO usermod -a -G render llama-swap
|
||||
fi
|
||||
if getent group video >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to video group..."
|
||||
$SUDO usermod -a -G video llama-swap
|
||||
fi
|
||||
if getent group docker >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to docker group..."
|
||||
$SUDO usermod -a -G docker llama-swap
|
||||
fi
|
||||
|
||||
status "Adding current user to llama-swap group..."
|
||||
$SUDO usermod -a -G llama-swap "$(whoami)"
|
||||
|
||||
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||
status "Creating default config.yaml..."
|
||||
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
|
||||
# default 15s likely to fail for default models due to downloading models
|
||||
healthCheckTimeout: 60
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
cmd: |
|
||||
docker run
|
||||
--rm
|
||||
-p \${PORT}:8080
|
||||
--name qwen2.5
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
cmdStop: docker stop qwen2.5
|
||||
|
||||
"smollm2":
|
||||
cmd: |
|
||||
docker run
|
||||
--rm
|
||||
-p \${PORT}:8080
|
||||
--name smollm2
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
cmdStop: docker stop smollm2
|
||||
EOF
|
||||
fi
|
||||
|
||||
status "Creating llama-swap systemd service..."
|
||||
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
User=llama-swap
|
||||
Group=llama-swap
|
||||
|
||||
# set this to match your environment
|
||||
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config -listen ${LLAMA_SWAP_DEFAULT_ADDRESS}
|
||||
|
||||
Restart=on-failure
|
||||
RestartSec=3
|
||||
StartLimitBurst=3
|
||||
StartLimitInterval=30
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
|
||||
case $SYSTEMCTL_RUNNING in
|
||||
running|degraded)
|
||||
status "Enabling and starting llama-swap service..."
|
||||
$SUDO systemctl daemon-reload
|
||||
$SUDO systemctl enable llama-swap
|
||||
|
||||
start_service() { $SUDO systemctl restart llama-swap; }
|
||||
trap start_service EXIT
|
||||
;;
|
||||
*)
|
||||
warning "systemd is not running"
|
||||
if [ "$IS_WSL2" = true ]; then
|
||||
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
if available systemctl; then
|
||||
configure_systemd
|
||||
fi
|
||||
|
||||
install_success() {
|
||||
status "The llama-swap API is now available at http://${LLAMA_SWAP_DEFAULT_ADDRESS}"
|
||||
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
|
||||
status 'Install complete.'
|
||||
}
|
||||
|
||||
# WSL2 only supports GPUs via nvidia passthrough
|
||||
# so check for nvidia-smi to determine if GPU is available
|
||||
if [ "$IS_WSL2" = true ]; then
|
||||
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||
status "Nvidia GPU detected."
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
install_success
|
||||
@@ -0,0 +1,68 @@
|
||||
#!/bin/sh
|
||||
# This script uninstalls llama-swap on Linux.
|
||||
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
|
||||
|
||||
set -eu
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||
|
||||
status() { echo ">>> $*" >&2; }
|
||||
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||
|
||||
available() { command -v $1 >/dev/null; }
|
||||
|
||||
SUDO=
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
if ! available sudo; then
|
||||
error "This script requires superuser permissions. Please re-run as root."
|
||||
fi
|
||||
|
||||
SUDO="sudo"
|
||||
fi
|
||||
|
||||
configure_systemd() {
|
||||
status "Stopping llama-swap service..."
|
||||
$SUDO systemctl stop llama-swap
|
||||
|
||||
status "Disabling llama-swap service..."
|
||||
$SUDO systemctl disable llama-swap
|
||||
}
|
||||
if available systemctl; then
|
||||
configure_systemd
|
||||
fi
|
||||
|
||||
if available llama-swap; then
|
||||
status "Removing llama-swap binary..."
|
||||
$SUDO rm $(which llama-swap)
|
||||
fi
|
||||
|
||||
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||
while true; do
|
||||
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
|
||||
read answer
|
||||
case "$answer" in
|
||||
[Yy]* )
|
||||
$SUDO rm -r /usr/share/llama-swap
|
||||
break
|
||||
;;
|
||||
[Nn]* | "" )
|
||||
break
|
||||
;;
|
||||
* )
|
||||
echo "Invalid input. Please enter y or n."
|
||||
;;
|
||||
esac
|
||||
done
|
||||
fi
|
||||
|
||||
if id llama-swap >/dev/null 2>&1; then
|
||||
status "Removing llama-swap user..."
|
||||
$SUDO userdel llama-swap
|
||||
fi
|
||||
|
||||
if getent group llama-swap >/dev/null 2>&1; then
|
||||
status "Removing llama-swap group..."
|
||||
$SUDO groupdel llama-swap
|
||||
fi
|
||||
Reference in New Issue
Block a user