Compare commits
89 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5fad24c16f | |||
| 8404244fab | |||
| 712cd01081 | |||
| 1f7aa359b1 | |||
| b138d6cf25 | |||
| fb7c808082 | |||
| a7e640b0f7 | |||
| 593604dfdc | |||
| b8f888f864 | |||
| 192b2ae621 | |||
| b7f8cb5094 | |||
| a23da6eb57 | |||
| 4c3aa40564 | |||
| 84e2c07a7e | |||
| 680af28bcc | |||
| d94db42ffe | |||
| 93cd83c55c | |||
| 5565fca3ac | |||
| d625ab8d92 | |||
| a3f82c140b | |||
| 5c97299e7b | |||
| 671c1a5a7b | |||
| 52c0196e0f | |||
| 3201a68a04 | |||
| 3ac94ad20e | |||
| 60355bf74a | |||
| 9b2ed244e2 | |||
| eeb72297f7 | |||
| eabfe70cc6 | |||
| 29cd98878d | |||
| b3d331da0d | |||
| 62275e078d | |||
| 88916059e1 | |||
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae | |||
| 1e25b44a06 | |||
| 0815bb4cc3 | |||
| 7187cfe52e | |||
| 24089d2d9c | |||
| ebabe55ff3 | |||
| 41a338297c | |||
| 7e3353efeb | |||
| 4ed58fb173 | |||
| f5a2be698d | |||
| f5e6ec3b7a | |||
| 3f462da146 | |||
| 48bd766536 | |||
| 8d319da4dd | |||
| be7c502448 | |||
| 92336f00bf | |||
| ed2a50d9a6 | |||
| 0acfdb9f78 | |||
| 96a8ea0241 | |||
| f20f2c9b7a | |||
| 7a97c38828 | |||
| 4885132565 | |||
| 8b46a0b7f1 | |||
| 1b6736ec6f | |||
| ddc1ce031e | |||
| 11d024bbaa | |||
| 43e23c16dc | |||
| f9c8e763ba | |||
| d7e1bb9f7c | |||
| ab93460a8b | |||
| 13d4552edc | |||
| 6667e307a2 | |||
| 7ac446e6a9 | |||
| eab9795bcc | |||
| 09bdd86b54 | |||
| 85cd74a51c | |||
| 314d2f2212 | |||
| fad25f3e11 | |||
| 2c3e3e27f7 | |||
| baeb0c4e7f | |||
| 2833517eef | |||
| abdc2bfdb3 | |||
| c3b834737f | |||
| 3c8e727b73 | |||
| 3a1e9f81f1 | |||
| 72c883f36c | |||
| 1b04d034cf | |||
| 2e45f5692a | |||
| c97b80bdfe | |||
| ae3ef9bc39 | |||
| db6715bec3 | |||
| da5d9e8a6a | |||
| 84b667ca7a | |||
| 29657106fc |
@@ -0,0 +1,23 @@
|
|||||||
|
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
||||||
|
name: Close inactive issues
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: "32 1 * * *"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
close-issues:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
steps:
|
||||||
|
- uses: actions/stale@v9
|
||||||
|
with:
|
||||||
|
days-before-issue-stale: 30
|
||||||
|
days-before-issue-close: 14
|
||||||
|
stale-issue-label: "stale"
|
||||||
|
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||||
|
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||||
|
days-before-pr-stale: -1
|
||||||
|
days-before-pr-close: -1
|
||||||
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
name: Build Containers
|
||||||
|
|
||||||
|
on:
|
||||||
|
# time has no specific meaning, trying to time it after
|
||||||
|
# the llama.cpp daily packages are published
|
||||||
|
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||||
|
schedule:
|
||||||
|
- cron: "37 5 * * *"
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-and-push:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
platform: [intel, cuda, vulkan, cpu, musa]
|
||||||
|
fail-fast: false
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Log in to GitHub Container Registry
|
||||||
|
uses: docker/login-action@v2
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Run build-container
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||||
|
|
||||||
|
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||||
|
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||||
|
delete-untagged-containers:
|
||||||
|
needs: build-and-push
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/delete-package-versions@v5
|
||||||
|
with:
|
||||||
|
package-name: 'llama-swap'
|
||||||
|
package-type: 'container'
|
||||||
|
delete-only-untagged-versions: 'true'
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
# This workflow will build a golang project
|
||||||
|
|
||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
run-tests:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# necessary for testing proxy/Process swapping
|
||||||
|
- name: Create simple-responder
|
||||||
|
run: make simple-responder
|
||||||
|
|
||||||
|
- name: Test all
|
||||||
|
run: make test-all
|
||||||
@@ -5,6 +5,9 @@ on:
|
|||||||
tags:
|
tags:
|
||||||
- '*'
|
- '*'
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
|
|
||||||
|
|||||||
+20
-1
@@ -6,6 +6,25 @@ builds:
|
|||||||
goos:
|
goos:
|
||||||
- linux
|
- linux
|
||||||
- darwin
|
- darwin
|
||||||
|
- freebsd
|
||||||
|
- windows
|
||||||
goarch:
|
goarch:
|
||||||
- amd64
|
- amd64
|
||||||
- arm64
|
- arm64
|
||||||
|
ignore:
|
||||||
|
- goos: freebsd
|
||||||
|
goarch: arm64
|
||||||
|
- goos: windows
|
||||||
|
goarch: arm64
|
||||||
|
|
||||||
|
# use zip format for windows
|
||||||
|
archives:
|
||||||
|
- id: default
|
||||||
|
format: tar.gz
|
||||||
|
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
|
builds_info:
|
||||||
|
group: root
|
||||||
|
owner: root
|
||||||
|
format_overrides:
|
||||||
|
- goos: windows
|
||||||
|
format: zip
|
||||||
@@ -35,6 +35,11 @@ linux:
|
|||||||
@echo "Building Linux binary..."
|
@echo "Building Linux binary..."
|
||||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||||
|
|
||||||
|
# Build Windows binary
|
||||||
|
windows:
|
||||||
|
@echo "Building Windows binary..."
|
||||||
|
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||||
|
|
||||||
# for testing proxy.Process
|
# for testing proxy.Process
|
||||||
simple-responder:
|
simple-responder:
|
||||||
@echo "Building simple responder"
|
@echo "Building simple responder"
|
||||||
@@ -60,4 +65,4 @@ release:
|
|||||||
git tag "$$new_tag";
|
git tag "$$new_tag";
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean osx linux
|
.PHONY: all clean mac linux windows simple-responder
|
||||||
|
|||||||
@@ -1,56 +1,94 @@
|
|||||||
|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||

|
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||||
|
|
||||||
# Introduction
|
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||||
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
|
||||||
|
|
||||||
Features:
|
## Features:
|
||||||
|
|
||||||
- ✅ Easy to deploy: single binary with no dependencies
|
- ✅ Easy to deploy: single binary with no dependencies
|
||||||
- ✅ Easy to config: single yaml file
|
- ✅ Easy to config: single yaml file
|
||||||
- ✅ On-demand model switching
|
- ✅ On-demand model switching
|
||||||
|
- ✅ OpenAI API supported endpoints:
|
||||||
|
- `v1/completions`
|
||||||
|
- `v1/chat/completions`
|
||||||
|
- `v1/embeddings`
|
||||||
|
- `v1/rerank`
|
||||||
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
|
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||||
|
- ✅ llama-swap custom API endpoints
|
||||||
|
- `/log` - remote log monitoring
|
||||||
|
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
|
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||||
|
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||||
|
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||||
|
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||||
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||||
|
- ✅ Docker and Podman support
|
||||||
- ✅ Full control over server settings per model
|
- ✅ Full control over server settings per model
|
||||||
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
|
||||||
- ✅ Multiple GPU support
|
|
||||||
- ✅ Run multiple models at once with `profiles`
|
|
||||||
- ✅ Remote log monitoring at `/log`
|
|
||||||
- ✅ Automatic unloading of models from GPUs after timeout
|
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabblyAPI, etc)
|
|
||||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
|
||||||
|
|
||||||
## Releases
|
## How does llama-swap work?
|
||||||
|
|
||||||
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
|
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||||
|
|
||||||
### Building from source
|
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||||
|
|
||||||
1. Install golang for your system
|
|
||||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
|
||||||
1. `make clean all`
|
|
||||||
1. Binaries will be in `build/` subdirectory
|
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
llama-swap's configuration is purposefully simple.
|
llama-swap's configuration is purposefully simple.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
"qwen2.5":
|
||||||
|
proxy: "http://127.0.0.1:9999"
|
||||||
|
cmd: >
|
||||||
|
/app/llama-server
|
||||||
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
|
--port 9999
|
||||||
|
|
||||||
|
"smollm2":
|
||||||
|
proxy: "http://127.0.0.1:9999"
|
||||||
|
cmd: >
|
||||||
|
/app/llama-server
|
||||||
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
|
--port 9999
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>But also very powerful ...</summary>
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||||
# Default (and minimum) is 15 seconds
|
# Default (and minimum) is 15 seconds
|
||||||
healthCheckTimeout: 60
|
healthCheckTimeout: 60
|
||||||
|
|
||||||
|
# Valid log levels: debug, info (default), warn, error
|
||||||
|
logLevel: info
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
# define valid model values and the upstream server start
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
# multiline for readability
|
||||||
|
cmd: >
|
||||||
|
llama-server --port 8999
|
||||||
|
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||||
|
|
||||||
|
# environment variables to pass to the command
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
# where to reach the server started by cmd, make sure the ports match
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# aliases names to use this model for
|
# aliases names to use this model for
|
||||||
aliases:
|
aliases:
|
||||||
- "gpt-4o-mini"
|
- "gpt-4o-mini"
|
||||||
- "gpt-3.5-turbo"
|
- "gpt-3.5-turbo"
|
||||||
|
|
||||||
# check this path for an HTTP 200 OK before serving requests
|
# check this path for an HTTP 200 OK before serving requests
|
||||||
# default: /health to match llama.cpp
|
# default: /health to match llama.cpp
|
||||||
@@ -63,24 +101,26 @@ models:
|
|||||||
# default: 0 = never unload model
|
# default: 0 = never unload model
|
||||||
ttl: 60
|
ttl: 60
|
||||||
|
|
||||||
"qwen":
|
# `useModelName` overrides the model name in the request
|
||||||
# environment variables to pass to the command
|
# and sends a specific name to the upstream server
|
||||||
env:
|
useModelName: "qwen:qwq"
|
||||||
- "CUDA_VISIBLE_DEVICES=0"
|
|
||||||
|
|
||||||
# multiline for readability
|
|
||||||
cmd: >
|
|
||||||
llama-server --port 8999
|
|
||||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
|
||||||
proxy: http://127.0.0.1:8999
|
|
||||||
|
|
||||||
# unlisted models do not show up in /v1/models or /upstream lists
|
# unlisted models do not show up in /v1/models or /upstream lists
|
||||||
# but they can still be requested as normal
|
# but they can still be requested as normal
|
||||||
"qwen-unlisted":
|
"qwen-unlisted":
|
||||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
|
||||||
unlisted: true
|
unlisted: true
|
||||||
|
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
# Docker Support (v26.1.4+ required!)
|
||||||
|
"docker-llama":
|
||||||
|
proxy: "http://127.0.0.1:9790"
|
||||||
|
cmd: >
|
||||||
|
docker run --name dockertest
|
||||||
|
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||||
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
|
# profiles eliminates swapping by running multiple models at the same time
|
||||||
#
|
#
|
||||||
# Tips:
|
# Tips:
|
||||||
# - each model must be listening on a unique address and port
|
# - each model must be listening on a unique address and port
|
||||||
@@ -88,23 +128,87 @@ models:
|
|||||||
# - the profile will load and unload all models in the profile at the same time
|
# - the profile will load and unload all models in the profile at the same time
|
||||||
profiles:
|
profiles:
|
||||||
coding:
|
coding:
|
||||||
- "qwen"
|
|
||||||
- "llama"
|
- "llama"
|
||||||
|
- "qwen-unlisted"
|
||||||
```
|
```
|
||||||
|
|
||||||
**Guides and examples**
|
### Use Case Examples
|
||||||
|
|
||||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||||
|
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
||||||
|
|
||||||
## Installation
|
## Configuration
|
||||||
|
|
||||||
|
llama-s
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||||
|
|
||||||
|
Docker is the quickest way to try out llama-swap:
|
||||||
|
|
||||||
|
```
|
||||||
|
# use CPU inference
|
||||||
|
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||||
|
|
||||||
|
|
||||||
|
# qwen2.5 0.5B
|
||||||
|
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer no-key" \
|
||||||
|
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||||
|
jq -r '.choices[0].message.content'
|
||||||
|
|
||||||
|
|
||||||
|
# SmolLM2 135M
|
||||||
|
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer no-key" \
|
||||||
|
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||||
|
jq -r '.choices[0].message.content'
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Docker images are nightly ...</summary>
|
||||||
|
|
||||||
|
They include:
|
||||||
|
|
||||||
|
- `ghcr.io/mostlygeek/llama-swap:cpu`
|
||||||
|
- `ghcr.io/mostlygeek/llama-swap:cuda`
|
||||||
|
- `ghcr.io/mostlygeek/llama-swap:intel`
|
||||||
|
- `ghcr.io/mostlygeek/llama-swap:vulkan`
|
||||||
|
- ROCm disabled until fixed in llama.cpp container
|
||||||
|
|
||||||
|
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
|
||||||
|
|
||||||
|
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||||
|
-v /path/to/models:/models \
|
||||||
|
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||||
|
ghcr.io/mostlygeek/llama-swap:cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||||
|
|
||||||
|
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||||
|
|
||||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||||
* _Note: Windows currently untested._
|
|
||||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||||
|
|
||||||
|
### Building from source
|
||||||
|
|
||||||
|
1. Install golang for your system
|
||||||
|
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||||
|
1. `make clean all`
|
||||||
|
1. Binaries will be in `build/` subdirectory
|
||||||
|
|
||||||
## Monitoring Logs
|
## Monitoring Logs
|
||||||
|
|
||||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||||
@@ -115,9 +219,15 @@ Of course, CLI access is also supported:
|
|||||||
# sends up to the last 10KB of logs
|
# sends up to the last 10KB of logs
|
||||||
curl http://host/logs'
|
curl http://host/logs'
|
||||||
|
|
||||||
# streams logs
|
# streams combined logs
|
||||||
curl -Ns 'http://host/logs/stream'
|
curl -Ns 'http://host/logs/stream'
|
||||||
|
|
||||||
|
# just llama-swap's logs
|
||||||
|
curl -Ns 'http://host/logs/stream/proxy'
|
||||||
|
|
||||||
|
# just upstream's logs
|
||||||
|
curl -Ns 'http://host/logs/stream/upstream'
|
||||||
|
|
||||||
# stream and filter logs with linux pipes
|
# stream and filter logs with linux pipes
|
||||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||||
|
|
||||||
@@ -125,11 +235,18 @@ curl -Ns http://host/logs/stream | grep 'eval time'
|
|||||||
curl -Ns 'http://host/logs/stream?no-history'
|
curl -Ns 'http://host/logs/stream?no-history'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Do I need to use llama.cpp's server (llama-server)?
|
||||||
|
|
||||||
|
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||||
|
|
||||||
|
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||||
|
|
||||||
## Systemd Unit Files
|
## Systemd Unit Files
|
||||||
|
|
||||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||||
|
|
||||||
`/etc/systemd/system/llama-swap.service`
|
`/etc/systemd/system/llama-swap.service`
|
||||||
|
|
||||||
```
|
```
|
||||||
[Unit]
|
[Unit]
|
||||||
Description=llama-swap
|
Description=llama-swap
|
||||||
@@ -149,3 +266,7 @@ StartLimitInterval=30
|
|||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||||
|
|||||||
+12
-1
@@ -1,6 +1,9 @@
|
|||||||
# Seconds to wait for llama.cpp to be available to serve requests
|
# Seconds to wait for llama.cpp to be available to serve requests
|
||||||
# Default (and minimum): 15 seconds
|
# Default (and minimum): 15 seconds
|
||||||
healthCheckTimeout: 15
|
healthCheckTimeout: 90
|
||||||
|
|
||||||
|
# valid log levels: debug, info (default), warn, error
|
||||||
|
logLevel: debug
|
||||||
|
|
||||||
models:
|
models:
|
||||||
"llama":
|
"llama":
|
||||||
@@ -50,6 +53,14 @@ models:
|
|||||||
--ctx-size 8192
|
--ctx-size 8192
|
||||||
--reranking
|
--reranking
|
||||||
|
|
||||||
|
# Docker Support (v26.1.4+ required!)
|
||||||
|
"dockertest":
|
||||||
|
proxy: "http://127.0.0.1:9790"
|
||||||
|
cmd: >
|
||||||
|
docker run --name dockertest
|
||||||
|
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
||||||
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
"simple":
|
"simple":
|
||||||
# example of setting environment variables
|
# example of setting environment variables
|
||||||
|
|||||||
Executable
+55
@@ -0,0 +1,55 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
cd $(dirname "$0")
|
||||||
|
|
||||||
|
ARCH=$1
|
||||||
|
PUSH_IMAGES=${2:-false}
|
||||||
|
|
||||||
|
# List of allowed architectures
|
||||||
|
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||||
|
|
||||||
|
# Check if ARCH is in the allowed list
|
||||||
|
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||||
|
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if GITHUB_TOKEN is set and not empty
|
||||||
|
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||||
|
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# the most recent llama-swap tag
|
||||||
|
# have to strip out the 'v' due to .tar.gz file naming
|
||||||
|
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||||
|
|
||||||
|
if [ "$ARCH" == "cpu" ]; then
|
||||||
|
# cpu only containers just use the latest available
|
||||||
|
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
||||||
|
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
||||||
|
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
||||||
|
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||||
|
docker push ${CONTAINER_LATEST}
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||||
|
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||||
|
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||||
|
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||||
|
|
||||||
|
# Abort if LCPP_TAG is empty.
|
||||||
|
if [[ -z "$LCPP_TAG" ]]; then
|
||||||
|
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||||
|
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||||
|
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||||
|
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||||
|
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||||
|
docker push ${CONTAINER_TAG}
|
||||||
|
docker push ${CONTAINER_LATEST}
|
||||||
|
fi
|
||||||
|
fi
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
healthCheckTimeout: 300
|
||||||
|
logRequests: true
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen2.5":
|
||||||
|
proxy: "http://127.0.0.1:9999"
|
||||||
|
cmd: >
|
||||||
|
/app/llama-server
|
||||||
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
|
--port 9999
|
||||||
|
|
||||||
|
"smollm2":
|
||||||
|
proxy: "http://127.0.0.1:9999"
|
||||||
|
cmd: >
|
||||||
|
/app/llama-server
|
||||||
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
|
--port 9999
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
ARG BASE_TAG=server-cuda
|
||||||
|
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||||
|
|
||||||
|
# has to be after the FROM
|
||||||
|
ARG LS_VER=89
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
RUN \
|
||||||
|
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||||
|
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||||
|
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||||
|
|
||||||
|
COPY config.example.yaml /app/config.yaml
|
||||||
|
|
||||||
|
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||||
|
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
|
||||||
|
|
||||||
|
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
|
||||||
|
|
||||||
|
## Here's what you you need:
|
||||||
|
|
||||||
|
- aider - [installation docs](https://aider.chat/docs/install.html)
|
||||||
|
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
|
||||||
|
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
|
||||||
|
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
|
||||||
|
- 24GB VRAM video card
|
||||||
|
|
||||||
|
## Running aider
|
||||||
|
|
||||||
|
The goal is getting this command line to work:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
aider --architect \
|
||||||
|
--no-show-model-warnings \
|
||||||
|
--model openai/QwQ \
|
||||||
|
--editor-model openai/qwen-coder-32B \
|
||||||
|
--model-settings-file aider.model.settings.yml \
|
||||||
|
--openai-api-key "sk-na" \
|
||||||
|
--openai-api-base "http://10.0.1.24:8080/v1" \
|
||||||
|
```
|
||||||
|
|
||||||
|
Set `--openai-api-base` to the IP and port where your llama-swap is running.
|
||||||
|
|
||||||
|
## Create an aider model settings file
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# aider.model.settings.yml
|
||||||
|
|
||||||
|
#
|
||||||
|
# !!! important: model names must match llama-swap configuration names !!!
|
||||||
|
#
|
||||||
|
|
||||||
|
- name: "openai/QwQ"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.95
|
||||||
|
top_k: 40
|
||||||
|
presence_penalty: 0.1
|
||||||
|
repetition_penalty: 1
|
||||||
|
num_ctx: 16384
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
weak_model_name: "openai/qwen-coder-32B"
|
||||||
|
editor_model_name: "openai/qwen-coder-32B"
|
||||||
|
|
||||||
|
- name: "openai/qwen-coder-32B"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 20
|
||||||
|
repetition_penalty: 1.05
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
editor_edit_format: editor-diff
|
||||||
|
editor_model_name: "openai/qwen-coder-32B"
|
||||||
|
```
|
||||||
|
|
||||||
|
## llama-swap configuration
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# config.yaml
|
||||||
|
|
||||||
|
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
|
||||||
|
models:
|
||||||
|
"qwen-coder-32B":
|
||||||
|
proxy: "http://127.0.0.1:8999"
|
||||||
|
cmd: >
|
||||||
|
/path/to/llama-server
|
||||||
|
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||||
|
--ctx-size 16000
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
-ngl 99
|
||||||
|
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||||
|
|
||||||
|
"QwQ":
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
cmd: >
|
||||||
|
/path/to/llama-server
|
||||||
|
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
--ctx-size 32000
|
||||||
|
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||||
|
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
|
||||||
|
--min-p 0.01 --top-k 40 --top-p 0.95
|
||||||
|
-ngl 99
|
||||||
|
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced, Dual GPU Configuration
|
||||||
|
|
||||||
|
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
|
||||||
|
|
||||||
|
In llama-swap's configuration file:
|
||||||
|
|
||||||
|
1. add a `profiles` section with `aider` as the profile name
|
||||||
|
2. using the `env` field to specify the GPU IDs for each model
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# config.yaml
|
||||||
|
|
||||||
|
# Add a profile for aider
|
||||||
|
profiles:
|
||||||
|
aider:
|
||||||
|
- qwen-coder-32B
|
||||||
|
- QwQ
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen-coder-32B":
|
||||||
|
# manually set the GPU to run on
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
proxy: "http://127.0.0.1:8999"
|
||||||
|
cmd: /path/to/llama-server ...
|
||||||
|
|
||||||
|
"QwQ":
|
||||||
|
# manually set the GPU to run on
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=1"
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
cmd: /path/to/llama-server ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Append the profile tag, `aider:`, to the model names in the model settings file
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# aider.model.settings.yml
|
||||||
|
- name: "openai/aider:QwQ"
|
||||||
|
weak_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||||
|
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||||
|
|
||||||
|
- name: "openai/aider:qwen-coder-32B"
|
||||||
|
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||||
|
```
|
||||||
|
|
||||||
|
Run aider with:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ aider --architect \
|
||||||
|
--no-show-model-warnings \
|
||||||
|
--model openai/aider:QwQ \
|
||||||
|
--editor-model openai/aider:qwen-coder-32B \
|
||||||
|
--config aider.conf.yml \
|
||||||
|
--model-settings-file aider.model.settings.yml
|
||||||
|
--openai-api-key "sk-na" \
|
||||||
|
--openai-api-base "http://10.0.1.24:8080/v1"
|
||||||
|
```
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# this makes use of llama-swap's profile feature to
|
||||||
|
# keep the architect and editor models in VRAM on different GPUs
|
||||||
|
|
||||||
|
- name: "openai/aider:QwQ"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.95
|
||||||
|
top_k: 40
|
||||||
|
presence_penalty: 0.1
|
||||||
|
repetition_penalty: 1
|
||||||
|
num_ctx: 16384
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
weak_model_name: "openai/aider:qwen-coder-32B"
|
||||||
|
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||||
|
|
||||||
|
- name: "openai/aider:qwen-coder-32B"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 20
|
||||||
|
repetition_penalty: 1.05
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
editor_edit_format: editor-diff
|
||||||
|
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
- name: "openai/QwQ"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.95
|
||||||
|
top_k: 40
|
||||||
|
presence_penalty: 0.1
|
||||||
|
repetition_penalty: 1
|
||||||
|
num_ctx: 16384
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
weak_model_name: "openai/qwen-coder-32B"
|
||||||
|
editor_model_name: "openai/qwen-coder-32B"
|
||||||
|
|
||||||
|
- name: "openai/qwen-coder-32B"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 20
|
||||||
|
repetition_penalty: 1.05
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
editor_edit_format: editor-diff
|
||||||
|
editor_model_name: "openai/qwen-coder-32B"
|
||||||
|
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
healthCheckTimeout: 300
|
||||||
|
logLevel: debug
|
||||||
|
|
||||||
|
profiles:
|
||||||
|
aider:
|
||||||
|
- qwen-coder-32B
|
||||||
|
- QwQ
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen-coder-32B":
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
aliases:
|
||||||
|
- coder
|
||||||
|
proxy: "http://127.0.0.1:8999"
|
||||||
|
|
||||||
|
# set appropriate paths for your environment
|
||||||
|
cmd: >
|
||||||
|
/path/to/llama-server
|
||||||
|
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||||
|
--ctx-size 16000
|
||||||
|
--ctx-size-draft 16000
|
||||||
|
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||||
|
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
|
||||||
|
-ngl 99 -ngld 99
|
||||||
|
--draft-max 16 --draft-min 4 --draft-p-min 0.4
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
"QwQ":
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=1"
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
|
||||||
|
# set appropriate paths for your environment
|
||||||
|
cmd: >
|
||||||
|
/path/to/llama-server
|
||||||
|
--host 127.0.0.1 --port 9503
|
||||||
|
--flash-attn --metrics
|
||||||
|
--slots
|
||||||
|
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
--ctx-size 32000
|
||||||
|
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||||
|
--temp 0.6
|
||||||
|
--repeat-penalty 1.1
|
||||||
|
--dry-multiplier 0.5
|
||||||
|
--min-p 0.01
|
||||||
|
--top-k 40
|
||||||
|
--top-p 0.95
|
||||||
|
-ngl 99 -ngld 99
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
# Restart llama-swap on config change
|
||||||
|
|
||||||
|
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# A simple watch and restart llama-swap when its configuration
|
||||||
|
# file changes. Useful for trying out configuration changes
|
||||||
|
# without manually restarting the server each time.
|
||||||
|
if [ -z "$1" ]; then
|
||||||
|
echo "Usage: $0 <path to config.yaml>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
# Start the process again
|
||||||
|
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
||||||
|
PID=$!
|
||||||
|
echo "Started llama-swap with PID $PID"
|
||||||
|
|
||||||
|
# Wait for modifications in the specified directory or file
|
||||||
|
inotifywait -e modify "$1"
|
||||||
|
|
||||||
|
# Check if process exists before sending signal
|
||||||
|
if kill -0 $PID 2>/dev/null; then
|
||||||
|
echo "Sending SIGTERM to $PID"
|
||||||
|
kill -SIGTERM $PID
|
||||||
|
wait $PID
|
||||||
|
else
|
||||||
|
echo "Process $PID no longer exists"
|
||||||
|
fi
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage and output example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ ./watch-and-restart.sh config.yaml
|
||||||
|
Started llama-swap with PID 495455
|
||||||
|
Setting up watches.
|
||||||
|
Watches established.
|
||||||
|
llama-swap listening on :1867
|
||||||
|
Sending SIGTERM to 495455
|
||||||
|
Shutting down llama-swap
|
||||||
|
Started llama-swap with PID 495486
|
||||||
|
Setting up watches.
|
||||||
|
Watches established.
|
||||||
|
llama-swap listening on :1867
|
||||||
|
```
|
||||||
@@ -3,7 +3,11 @@ module github.com/mostlygeek/llama-swap
|
|||||||
go 1.23.0
|
go 1.23.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/gin-gonic/gin v1.10.0
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
|
github.com/tidwall/gjson v1.18.0
|
||||||
|
github.com/tidwall/sjson v1.2.5
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,12 +19,10 @@ require (
|
|||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
github.com/gin-gonic/gin v1.10.0 // indirect
|
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
@@ -29,12 +31,14 @@ require (
|
|||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.31.0 // indirect
|
golang.org/x/crypto v0.36.0 // indirect
|
||||||
golang.org/x/net v0.25.0 // indirect
|
golang.org/x/net v0.38.0 // indirect
|
||||||
golang.org/x/sys v0.28.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.21.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.1 // indirect
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
@@ -68,18 +78,30 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
|||||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||||
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||||
|
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||||
|
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||||
|
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||||
|
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||||
|
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||||
|
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||||
|
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||||
|
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
|||||||
BIN
Binary file not shown.
|
After Width: | Height: | Size: 351 KiB |
@@ -4,6 +4,8 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
@@ -39,6 +41,16 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
proxyManager := proxy.New(config)
|
proxyManager := proxy.New(config)
|
||||||
|
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
<-sigChan
|
||||||
|
fmt.Println("Shutting down llama-swap")
|
||||||
|
proxyManager.Shutdown()
|
||||||
|
os.Exit(0)
|
||||||
|
}()
|
||||||
|
|
||||||
fmt.Println("llama-swap listening on " + *listenStr)
|
fmt.Println("llama-swap listening on " + *listenStr)
|
||||||
if err := proxyManager.Run(*listenStr); err != nil {
|
if err := proxyManager.Run(*listenStr); err != nil {
|
||||||
fmt.Printf("Server error: %v\n", err)
|
fmt.Printf("Server error: %v\n", err)
|
||||||
|
|||||||
@@ -12,12 +12,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
port := flag.String("port", "8080", "port to listen on")
|
port := flag.String("port", "8080", "port to listen on")
|
||||||
|
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||||
|
|
||||||
// Define a command-line flag for the response message
|
// Define a command-line flag for the response message
|
||||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||||
@@ -41,11 +43,70 @@ func main() {
|
|||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// for issue #62 to check model name strips profile slug
|
||||||
|
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||||
|
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||||
|
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Request.Body.Close()
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
if modelName != *expectedModel {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// issue #41
|
||||||
|
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||||
|
// Parse the multipart form
|
||||||
|
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the model from the form values
|
||||||
|
model := c.Request.FormValue("model")
|
||||||
|
|
||||||
|
if model == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the file from the form
|
||||||
|
file, _, err := c.Request.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Read the file content to get its size
|
||||||
|
fileBytes, err := io.ReadAll(file)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fileSize := len(fileBytes)
|
||||||
|
|
||||||
|
// Return a JSON response with the model and transcription text including file size
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||||
|
"model": model,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
r.GET("/slow-respond", func(c *gin.Context) {
|
r.GET("/slow-respond", func(c *gin.Context) {
|
||||||
echo := c.Query("echo")
|
echo := c.Query("echo")
|
||||||
delay := c.Query("delay")
|
delay := c.Query("delay")
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type ModelConfig struct {
|
|||||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||||
UnloadAfter int `yaml:"ttl"`
|
UnloadAfter int `yaml:"ttl"`
|
||||||
Unlisted bool `yaml:"unlisted"`
|
Unlisted bool `yaml:"unlisted"`
|
||||||
|
UseModelName string `yaml:"useModelName"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
@@ -25,6 +26,8 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
|
LogRequests bool `yaml:"logRequests"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
Models map[string]ModelConfig `yaml:"models"`
|
Models map[string]ModelConfig `yaml:"models"`
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>llama-swap</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>llama-swap</h1>
|
||||||
|
<p>
|
||||||
|
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
||||||
|
</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
+228
-23
@@ -12,42 +12,247 @@
|
|||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
font-family: "Courier New", Courier, monospace;
|
font-family: "Courier New", Courier, monospace;
|
||||||
}
|
}
|
||||||
#log-stream {
|
.log-container {
|
||||||
|
display: flex;
|
||||||
flex: 1;
|
flex: 1;
|
||||||
margin: 1em;
|
gap: 0.5em;
|
||||||
padding: 10px;
|
margin: 0.5em;
|
||||||
|
min-height: 0;
|
||||||
|
}
|
||||||
|
.log-column {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
flex: 1;
|
||||||
|
min-width: 0;
|
||||||
|
transition: flex 0.3s ease;
|
||||||
|
}
|
||||||
|
.log-column.minimized {
|
||||||
|
flex: 0.1;
|
||||||
|
max-width: 50px;
|
||||||
|
border: 1px solid #777;
|
||||||
|
color: green;
|
||||||
|
}
|
||||||
|
.log-controls {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 1fr auto;
|
||||||
|
gap: 0.5em;
|
||||||
|
margin-bottom: 0.5em;
|
||||||
|
}
|
||||||
|
.log-controls input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 4px;
|
||||||
|
}
|
||||||
|
.log-controls input:focus {
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
.log-stream {
|
||||||
|
flex: 1;
|
||||||
|
padding: 1em;
|
||||||
background: #f4f4f4;
|
background: #f4f4f4;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
white-space: pre-wrap; /* Ensures line wrapping */
|
white-space: pre-wrap;
|
||||||
word-wrap: break-word; /* Ensures long words wrap */
|
word-wrap: break-word;
|
||||||
|
min-height: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.regex-error {
|
||||||
|
background-color: #ff0000 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Make headers clickable and show pointer cursor */
|
||||||
|
h2 {
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
margin: 0 0 0.5em 0;
|
||||||
|
padding: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
h2:hover {
|
||||||
|
background-color: rgba(0, 0, 0, 0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Dark mode styles */
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
body {
|
||||||
|
background-color: #333;
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-stream {
|
||||||
|
background: #444;
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-controls input {
|
||||||
|
background: #555;
|
||||||
|
color: #fff;
|
||||||
|
border: 1px solid #777;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-controls button {
|
||||||
|
background: #555;
|
||||||
|
color: #fff;
|
||||||
|
border: 1px solid #777;
|
||||||
|
}
|
||||||
|
|
||||||
|
h2:hover {
|
||||||
|
background-color: rgba(255, 255, 255, 0.1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Hide content when minimized */
|
||||||
|
.log-column.minimized .log-controls,
|
||||||
|
.log-column.minimized .log-stream {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-column.minimized h2 {
|
||||||
|
writing-mode: vertical-rl;
|
||||||
|
text-orientation: mixed;
|
||||||
|
transform: rotate(180deg);
|
||||||
|
white-space: nowrap;
|
||||||
|
margin: auto;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<pre id="log-stream">Waiting for logs...
|
<div class="log-container">
|
||||||
</pre>
|
<div class="log-column">
|
||||||
|
<h2>Proxy Logs</h2>
|
||||||
|
<div class="log-controls">
|
||||||
|
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
|
||||||
|
<button id="proxy-clear-button">clear</button>
|
||||||
|
</div>
|
||||||
|
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
|
||||||
|
</div>
|
||||||
|
<div class="log-column minimized">
|
||||||
|
<h2>Upstream Logs</h2>
|
||||||
|
<div class="log-controls">
|
||||||
|
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
|
||||||
|
<button id="upstream-clear-button">clear</button>
|
||||||
|
</div>
|
||||||
|
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
<script>
|
<script>
|
||||||
// Establish an EventSource connection to the SSE endpoint
|
class LogStream {
|
||||||
if (typeof(EventSource) !== "undefined") {
|
constructor(streamElement, filterInput, clearButton, endpoint) {
|
||||||
const eventSource = new EventSource("/logs/streamSSE");
|
this.streamElement = streamElement;
|
||||||
|
this.filterInput = filterInput;
|
||||||
|
this.clearButton = clearButton;
|
||||||
|
this.endpoint = endpoint;
|
||||||
|
this.logData = "";
|
||||||
|
this.regexFilter = null;
|
||||||
|
this.eventSource = null;
|
||||||
|
|
||||||
eventSource.onmessage = function(event) {
|
this.initialize();
|
||||||
// Append the new log message to the <pre> element
|
}
|
||||||
const logStream = document.getElementById('log-stream');
|
|
||||||
|
|
||||||
logStream.textContent += event.data;
|
initialize() {
|
||||||
|
this.filterInput.addEventListener('input', () => this.updateFilter());
|
||||||
|
this.clearButton.addEventListener('click', () => {
|
||||||
|
this.filterInput.value = "";
|
||||||
|
this.regexFilter = null;
|
||||||
|
this.render();
|
||||||
|
});
|
||||||
|
this.setupEventSource();
|
||||||
|
}
|
||||||
|
|
||||||
// Auto-scroll to the bottom
|
setupEventSource() {
|
||||||
logStream.scrollTop = logStream.scrollHeight;
|
if (typeof(EventSource) === "undefined") {
|
||||||
};
|
this.logData = "SSE Not supported by this browser.";
|
||||||
|
this.render();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
eventSource.onerror = function(err) {
|
const connect = () => {
|
||||||
console.error("EventSource failed:", err);
|
this.eventSource = new EventSource(this.endpoint);
|
||||||
};
|
|
||||||
} else {
|
this.eventSource.onmessage = (event) => {
|
||||||
console.error("SSE not supported by this browser.");
|
this.logData += event.data;
|
||||||
|
this.render();
|
||||||
|
};
|
||||||
|
|
||||||
|
this.eventSource.onerror = (err) => {
|
||||||
|
// Close the current connection
|
||||||
|
this.eventSource.close();
|
||||||
|
|
||||||
|
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
|
||||||
|
this.render();
|
||||||
|
|
||||||
|
// Attempt to reconnect after 5 seconds
|
||||||
|
setTimeout(() => {
|
||||||
|
this.logData += "Attempting to reconnect...\n";
|
||||||
|
this.render();
|
||||||
|
connect();
|
||||||
|
}, 5000);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initial connection
|
||||||
|
connect();
|
||||||
|
}
|
||||||
|
|
||||||
|
render() {
|
||||||
|
let content = this.logData;
|
||||||
|
|
||||||
|
if (this.regexFilter) {
|
||||||
|
const lines = content.split('\n');
|
||||||
|
const filteredLines = lines.filter(line => this.regexFilter.test(line));
|
||||||
|
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
|
||||||
|
}
|
||||||
|
|
||||||
|
this.streamElement.textContent = content;
|
||||||
|
this.streamElement.scrollTop = this.streamElement.scrollHeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
updateFilter() {
|
||||||
|
const pattern = this.filterInput.value.trim();
|
||||||
|
this.filterInput.classList.remove('regex-error');
|
||||||
|
|
||||||
|
if (!pattern) {
|
||||||
|
this.regexFilter = null;
|
||||||
|
this.render();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.regexFilter = new RegExp(pattern);
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Invalid regex pattern:", e);
|
||||||
|
this.regexFilter = null;
|
||||||
|
this.filterInput.classList.add('regex-error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.render();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize both log streams
|
||||||
|
document.addEventListener('DOMContentLoaded', () => {
|
||||||
|
new LogStream(
|
||||||
|
document.getElementById('proxy-log-stream'),
|
||||||
|
document.getElementById('proxy-filter-input'),
|
||||||
|
document.getElementById('proxy-clear-button'),
|
||||||
|
"/logs/streamSSE/proxy"
|
||||||
|
);
|
||||||
|
|
||||||
|
new LogStream(
|
||||||
|
document.getElementById('upstream-log-stream'),
|
||||||
|
document.getElementById('upstream-filter-input'),
|
||||||
|
document.getElementById('upstream-clear-button'),
|
||||||
|
"/logs/streamSSE/upstream"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Initialize clickable headers
|
||||||
|
document.querySelectorAll('h2').forEach(header => {
|
||||||
|
header.addEventListener('click', () => {
|
||||||
|
const column = header.closest('.log-column');
|
||||||
|
column.classList.toggle('minimized');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "embed"
|
||||||
|
|
||||||
|
//go:embed html
|
||||||
|
var htmlFiles embed.FS
|
||||||
|
|
||||||
|
func getHTMLFile(path string) ([]byte, error) {
|
||||||
|
return htmlFiles.ReadFile("html/" + path)
|
||||||
|
}
|
||||||
@@ -2,11 +2,21 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"container/ring"
|
"container/ring"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type LogLevel int
|
||||||
|
|
||||||
|
const (
|
||||||
|
LevelDebug LogLevel = iota
|
||||||
|
LevelInfo
|
||||||
|
LevelWarn
|
||||||
|
LevelError
|
||||||
|
)
|
||||||
|
|
||||||
type LogMonitor struct {
|
type LogMonitor struct {
|
||||||
clients map[chan []byte]bool
|
clients map[chan []byte]bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@@ -15,6 +25,10 @@ type LogMonitor struct {
|
|||||||
|
|
||||||
// typically this can be os.Stdout
|
// typically this can be os.Stdout
|
||||||
stdout io.Writer
|
stdout io.Writer
|
||||||
|
|
||||||
|
// logging levels
|
||||||
|
level LogLevel
|
||||||
|
prefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLogMonitor() *LogMonitor {
|
func NewLogMonitor() *LogMonitor {
|
||||||
@@ -26,6 +40,8 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
|||||||
clients: make(map[chan []byte]bool),
|
clients: make(map[chan []byte]bool),
|
||||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||||
stdout: stdout,
|
stdout: stdout,
|
||||||
|
level: LevelInfo,
|
||||||
|
prefix: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,3 +110,77 @@ func (w *LogMonitor) broadcast(msg []byte) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
w.prefix = prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
w.level = level
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||||
|
prefix := ""
|
||||||
|
if w.prefix != "" {
|
||||||
|
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||||
|
if level < w.level {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(w.formatMessage(level.String(), msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Debug(msg string) {
|
||||||
|
w.log(LevelDebug, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Info(msg string) {
|
||||||
|
w.log(LevelInfo, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Warn(msg string) {
|
||||||
|
w.log(LevelWarn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Error(msg string) {
|
||||||
|
w.log(LevelError, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
|
||||||
|
w.log(LevelDebug, fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Infof(format string, args ...interface{}) {
|
||||||
|
w.log(LevelInfo, fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
|
||||||
|
w.log(LevelWarn, fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
|
||||||
|
w.log(LevelError, fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l LogLevel) String() string {
|
||||||
|
switch l {
|
||||||
|
case LevelDebug:
|
||||||
|
return "DEBUG"
|
||||||
|
case LevelInfo:
|
||||||
|
return "INFO"
|
||||||
|
case LevelWarn:
|
||||||
|
return "WARN"
|
||||||
|
case LevelError:
|
||||||
|
return "ERROR"
|
||||||
|
default:
|
||||||
|
return "UNKNOWN"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+305
-165
@@ -17,19 +17,31 @@ import (
|
|||||||
type ProcessState string
|
type ProcessState string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
StateStopped ProcessState = ProcessState("stopped")
|
StateStopped ProcessState = ProcessState("stopped")
|
||||||
StateReady ProcessState = ProcessState("ready")
|
StateStarting ProcessState = ProcessState("starting")
|
||||||
StateFailed ProcessState = ProcessState("failed")
|
StateReady ProcessState = ProcessState("ready")
|
||||||
|
StateStopping ProcessState = ProcessState("stopping")
|
||||||
|
|
||||||
|
// failed a health check on start and will not be recovered
|
||||||
|
StateFailed ProcessState = ProcessState("failed")
|
||||||
|
|
||||||
|
// process is shutdown and will not be restarted
|
||||||
|
StateShutdown ProcessState = ProcessState("shutdown")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
sync.Mutex
|
ID string
|
||||||
|
config ModelConfig
|
||||||
|
cmd *exec.Cmd
|
||||||
|
|
||||||
ID string
|
// for p.cmd.Wait() select { ... }
|
||||||
config ModelConfig
|
cmdWaitChan chan error
|
||||||
cmd *exec.Cmd
|
|
||||||
logMonitor *LogMonitor
|
processLogger *LogMonitor
|
||||||
healthCheckTimeout int
|
proxyLogger *LogMonitor
|
||||||
|
|
||||||
|
healthCheckTimeout int
|
||||||
|
healthCheckLoopInterval time.Duration
|
||||||
|
|
||||||
lastRequestHandled time.Time
|
lastRequestHandled time.Time
|
||||||
|
|
||||||
@@ -37,31 +49,94 @@ type Process struct {
|
|||||||
state ProcessState
|
state ProcessState
|
||||||
|
|
||||||
inFlightRequests sync.WaitGroup
|
inFlightRequests sync.WaitGroup
|
||||||
|
|
||||||
|
// used to block on multiple start() calls
|
||||||
|
waitStarting sync.WaitGroup
|
||||||
|
|
||||||
|
// for managing shutdown state
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
logMonitor: logMonitor,
|
cmdWaitChan: make(chan error, 1),
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
processLogger: processLogger,
|
||||||
state: StateStopped,
|
proxyLogger: proxyLogger,
|
||||||
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
|
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||||
|
state: StateStopped,
|
||||||
|
shutdownCtx: ctx,
|
||||||
|
shutdownCancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// start the process and returns when it is ready
|
// LogMonitor returns the log monitor associated with the process.
|
||||||
func (p *Process) start() error {
|
func (p *Process) LogMonitor() *LogMonitor {
|
||||||
|
return p.processLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// custom error types for swapping state
|
||||||
|
var (
|
||||||
|
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||||
|
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||||
|
)
|
||||||
|
|
||||||
|
// swapState performs a compare and swap of the state atomically. It returns the current state
|
||||||
|
// and an error if the swap failed.
|
||||||
|
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
|
||||||
p.stateMutex.Lock()
|
p.stateMutex.Lock()
|
||||||
defer p.stateMutex.Unlock()
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
if p.state == StateReady {
|
if p.state != expectedState {
|
||||||
return nil
|
p.proxyLogger.Warnf("swapState() Unexpected current state %s, expected %s", p.state, expectedState)
|
||||||
|
return p.state, ErrExpectedStateMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.state == StateFailed {
|
if !isValidTransition(p.state, newState) {
|
||||||
return fmt.Errorf("process is in a failed state and can not be restarted")
|
p.proxyLogger.Warnf("swapState() Invalid state transition from %s to %s", p.state, newState)
|
||||||
|
return p.state, ErrInvalidStateTransition
|
||||||
|
}
|
||||||
|
|
||||||
|
p.state = newState
|
||||||
|
p.proxyLogger.Debugf("swapState() State transitioned from %s to %s", expectedState, newState)
|
||||||
|
return p.state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to encapsulate transition rules
|
||||||
|
func isValidTransition(from, to ProcessState) bool {
|
||||||
|
switch from {
|
||||||
|
case StateStopped:
|
||||||
|
return to == StateStarting
|
||||||
|
case StateStarting:
|
||||||
|
return to == StateReady || to == StateFailed || to == StateStopping
|
||||||
|
case StateReady:
|
||||||
|
return to == StateStopping
|
||||||
|
case StateStopping:
|
||||||
|
return to == StateStopped || to == StateShutdown
|
||||||
|
case StateFailed, StateShutdown:
|
||||||
|
return false // No transitions allowed from these states
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Process) CurrentState() ProcessState {
|
||||||
|
p.stateMutex.RLock()
|
||||||
|
defer p.stateMutex.RUnlock()
|
||||||
|
return p.state
|
||||||
|
}
|
||||||
|
|
||||||
|
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||||
|
// it is a private method because starting is automatic but stopping can be called
|
||||||
|
// at any time.
|
||||||
|
func (p *Process) start() error {
|
||||||
|
|
||||||
|
if p.config.Proxy == "" {
|
||||||
|
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
args, err := p.config.SanitizedCommand()
|
args, err := p.config.SanitizedCommand()
|
||||||
@@ -69,52 +144,129 @@ func (p *Process) start() error {
|
|||||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||||
|
if err == ErrExpectedStateMismatch {
|
||||||
|
// already starting, just wait for it to complete and expect
|
||||||
|
// it to be be in the Ready start after. If not, return an error
|
||||||
|
if curState == StateStarting {
|
||||||
|
p.waitStarting.Wait()
|
||||||
|
if state := p.CurrentState(); state == StateReady {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("process was already starting but wound up in state %v", state)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("processes was in state %v when start() was called", curState)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.waitStarting.Add(1)
|
||||||
|
defer p.waitStarting.Done()
|
||||||
|
|
||||||
p.cmd = exec.Command(args[0], args[1:]...)
|
p.cmd = exec.Command(args[0], args[1:]...)
|
||||||
p.cmd.Stdout = p.logMonitor
|
p.cmd.Stdout = p.processLogger
|
||||||
p.cmd.Stderr = p.logMonitor
|
p.cmd.Stderr = p.processLogger
|
||||||
p.cmd.Env = p.config.Env
|
p.cmd.Env = p.config.Env
|
||||||
|
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
|
|
||||||
|
// Set process state to failed
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||||
|
err, curState, swapErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("start() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture the exit error for later signaling
|
||||||
|
go func() {
|
||||||
|
exitErr := p.cmd.Wait()
|
||||||
|
p.proxyLogger.Debugf("cmd.Wait() returned for [%s] error: %v", p.ID, exitErr)
|
||||||
|
p.cmdWaitChan <- exitErr
|
||||||
|
}()
|
||||||
|
|
||||||
// One of three things can happen at this stage:
|
// One of three things can happen at this stage:
|
||||||
// 1. The command exits unexpectedly
|
// 1. The command exits unexpectedly
|
||||||
// 2. The health check fails
|
// 2. The health check fails
|
||||||
// 3. The health check passes
|
// 3. The health check passes
|
||||||
//
|
//
|
||||||
// only in the third case will the process be considered Ready to accept
|
// only in the third case will the process be considered Ready to accept
|
||||||
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
|
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
||||||
defer cancelHealthCheck(nil) // clean up
|
|
||||||
cmdWaitChan := make(chan error, 1)
|
|
||||||
healthCheckChan := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
checkStartTime := time.Now()
|
||||||
// possible cmd exits early
|
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||||
cmdWaitChan <- p.cmd.Wait()
|
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||||
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
if checkEndpoint != "none" {
|
||||||
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
|
// keep default behaviour
|
||||||
}()
|
if checkEndpoint == "" {
|
||||||
|
checkEndpoint = "/health"
|
||||||
select {
|
|
||||||
case err := <-cmdWaitChan:
|
|
||||||
p.state = StateFailed
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
|
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
|
|
||||||
}
|
}
|
||||||
cancelHealthCheck(err)
|
|
||||||
return err
|
proxyTo := p.config.Proxy
|
||||||
case err := <-healthCheckChan:
|
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.state = StateFailed
|
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||||
return err
|
}
|
||||||
|
|
||||||
|
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
||||||
|
context.Background(),
|
||||||
|
checkStartTime.Add(maxDuration),
|
||||||
|
)
|
||||||
|
defer cancelHealthCheck()
|
||||||
|
|
||||||
|
loop:
|
||||||
|
// Ready Check loop
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-checkDeadline.Done():
|
||||||
|
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||||
|
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||||
|
}
|
||||||
|
case <-p.shutdownCtx.Done():
|
||||||
|
return errors.New("health check interrupted due to shutdown")
|
||||||
|
case exitErr := <-p.cmdWaitChan:
|
||||||
|
if exitErr != nil {
|
||||||
|
p.proxyLogger.Warnf("upstream command exited prematurely with error: %v", exitErr)
|
||||||
|
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||||
|
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Warnf("upstream command exited prematurely with no error")
|
||||||
|
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
||||||
|
return fmt.Errorf("upstream command exited prematurely with no error AND state swap failed: %v, current state: %v", err, curState)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("upstream command exited prematurely with no error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||||
|
p.proxyLogger.Infof("Health check passed on %s", healthURL)
|
||||||
|
cancelHealthCheck()
|
||||||
|
break loop
|
||||||
|
} else {
|
||||||
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
|
endTime, _ := checkDeadline.Deadline()
|
||||||
|
ttl := time.Until(endTime)
|
||||||
|
p.proxyLogger.Infof("Connection refused on %s, giving up in %.0fs", healthURL, ttl.Seconds())
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
<-time.After(p.healthCheckLoopInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +277,7 @@ func (p *Process) start() error {
|
|||||||
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||||
|
|
||||||
for range time.Tick(time.Second) {
|
for range time.Tick(time.Second) {
|
||||||
if p.state != StateReady {
|
if p.CurrentState() != StateReady {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,168 +285,152 @@ func (p *Process) start() error {
|
|||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
|
||||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
|
||||||
p.Stop()
|
p.Stop()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
p.state = StateReady
|
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||||
return nil
|
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
// wait for any inflight requests before proceeding
|
// wait for any inflight requests before proceeding
|
||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
p.proxyLogger.Debugf("Stopping process [%s]", p.ID)
|
||||||
|
|
||||||
p.stateMutex.Lock()
|
// calling Stop() when state is invalid is a no-op
|
||||||
defer p.stateMutex.Unlock()
|
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||||
|
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
|
||||||
if p.state != StateReady {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stop the process with a graceful exit timeout
|
||||||
|
p.stopCommand(5 * time.Second)
|
||||||
|
|
||||||
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
|
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||||
|
// of time for any inflight requests to complete before shutting down. If the Process
|
||||||
|
// is in the state of starting, it will cancel it and shut it down
|
||||||
|
func (p *Process) Shutdown() {
|
||||||
|
p.shutdownCancel()
|
||||||
|
p.stopCommand(5 * time.Second)
|
||||||
|
p.state = StateShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||||
|
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||||
|
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||||
|
stopStartTime := time.Now()
|
||||||
|
defer func() {
|
||||||
|
p.proxyLogger.Debugf("Process [%s] stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||||
|
}()
|
||||||
|
|
||||||
|
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
||||||
|
defer cancelTimeout()
|
||||||
|
|
||||||
if p.cmd == nil || p.cmd.Process == nil {
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
// this situation should never happen... but if it does just update the state
|
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
|
||||||
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
|
|
||||||
p.state = StateStopped
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pretty sure this stopping code needs some work for windows and
|
if err := p.terminateProcess(); err != nil {
|
||||||
// will be a source of pain in the future.
|
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
|
||||||
|
}
|
||||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
sigtermNormal := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
sigtermNormal <- p.cmd.Wait()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-sigtermTimeout.Done():
|
case <-sigtermTimeout.Done():
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process for %s timed out waiting to stop\n", p.ID)
|
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
|
||||||
p.cmd.Process.Kill()
|
p.cmd.Process.Kill()
|
||||||
p.cmd.Wait()
|
case err := <-p.cmdWaitChan:
|
||||||
case err := <-sigtermNormal:
|
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
|
||||||
|
// because if we make it here then the cmd has been successfully running and made it
|
||||||
|
// through the health check. There is a possibility that ithe cmd crashed after the health check
|
||||||
|
// succeeded but that's not a case llama-swap is handling for now.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() != "wait: no child processes" {
|
if errno, ok := err.(syscall.Errno); ok {
|
||||||
// possible that simple-responder for testing is just not
|
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
|
||||||
// existing right, so suppress those errors.
|
} else if exitError, ok := err.(*exec.ExitError); ok {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process for %s stopped with error > %v\n", p.ID, err)
|
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||||
}
|
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
|
||||||
}
|
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||||
}
|
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
|
||||||
p.state = StateStopped
|
} else {
|
||||||
}
|
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||||
|
|
||||||
func (p *Process) CurrentState() ProcessState {
|
|
||||||
p.stateMutex.RLock()
|
|
||||||
defer p.stateMutex.RUnlock()
|
|
||||||
return p.state
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
|
||||||
if p.config.Proxy == "" {
|
|
||||||
return fmt.Errorf("no upstream available to check /health")
|
|
||||||
}
|
|
||||||
|
|
||||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
|
||||||
|
|
||||||
if checkEndpoint == "none" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// keep default behaviour
|
|
||||||
if checkEndpoint == "" {
|
|
||||||
checkEndpoint = "/health"
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
|
||||||
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
|
||||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
for {
|
|
||||||
req, err := http.NewRequest("GET", healthURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
|
|
||||||
defer cancel()
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
|
|
||||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
// check if the context was cancelled
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
err := context.Cause(ctx)
|
|
||||||
if !errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait a bit longer for TCP connection issues
|
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
|
||||||
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
} else {
|
} else {
|
||||||
time.Sleep(time.Second)
|
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ttl < 0 {
|
|
||||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if ttl < 0 {
|
|
||||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: 500 * time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", healthURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// got a response but it was not an OK
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestBeginTime := time.Now()
|
||||||
|
var startDuration time.Duration
|
||||||
|
|
||||||
|
// prevent new requests from being made while stopping or irrecoverable
|
||||||
|
currentState := p.CurrentState()
|
||||||
|
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
||||||
|
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.inFlightRequests.Add(1)
|
p.inFlightRequests.Add(1)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
p.lastRequestHandled = time.Now()
|
p.lastRequestHandled = time.Now()
|
||||||
p.inFlightRequests.Done()
|
p.inFlightRequests.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// start the process on demand
|
||||||
if p.CurrentState() != StateReady {
|
if p.CurrentState() != StateReady {
|
||||||
|
beginStartTime := time.Now()
|
||||||
if err := p.start(); err != nil {
|
if err := p.start(); err != nil {
|
||||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||||
http.Error(w, errstr, http.StatusInternalServerError)
|
http.Error(w, errstr, http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
startDuration = time.Since(beginStartTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
proxyTo := p.config.Proxy
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
@@ -333,4 +469,8 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalTime := time.Since(requestBeginTime)
|
||||||
|
p.proxyLogger.Debugf("Process [%s] request %s - start: %v, total: %v",
|
||||||
|
p.ID, r.RequestURI, startDuration, totalTime)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func (p *Process) terminateProcess() error {
|
||||||
|
return p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
|
}
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Process) terminateProcess() error {
|
||||||
|
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
||||||
|
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
+189
-13
@@ -2,7 +2,6 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@@ -13,13 +12,26 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
debugLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// flip to help with debugging tests
|
||||||
|
if false {
|
||||||
|
debugLogger.SetLogLevel(LevelDebug)
|
||||||
|
} else {
|
||||||
|
debugLogger.SetLogLevel(LevelError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
|
||||||
expectedMessage := "testing91931"
|
expectedMessage := "testing91931"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
// Create a process
|
// Create a process
|
||||||
process := NewProcess("test-process", 5, config, logMonitor)
|
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
@@ -48,6 +60,32 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||||
|
// are all handled successfully, even though they all may ask for the process to .start()
|
||||||
|
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||||
|
|
||||||
|
expectedMessage := "testing91931"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(reqID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||||
|
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
}
|
||||||
|
|
||||||
// test that the automatic start returns the expected error type
|
// test that the automatic start returns the expected error type
|
||||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||||
// Create a process configuration
|
// Create a process configuration
|
||||||
@@ -57,17 +95,20 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// test that the process unloads after the TTL
|
|
||||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping long auto unload TTL test")
|
t.Skip("skipping long auto unload TTL test")
|
||||||
@@ -79,7 +120,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
config.UnloadAfter = 3 // seconds
|
config.UnloadAfter = 3 // seconds
|
||||||
assert.Equal(t, 3, config.UnloadAfter)
|
assert.Equal(t, 3, config.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
// this should take 4 seconds
|
// this should take 4 seconds
|
||||||
@@ -111,7 +152,36 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
assert.Equal(t, StateStopped, process.CurrentState())
|
assert.Equal(t, StateStopped, process.CurrentState())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_LowTTLValue(t *testing.T) {
|
||||||
|
if true { // change this code to run this ...
|
||||||
|
t.Skip("skipping test, edit process_test.go to run it ")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := getTestSimpleResponderConfig("fast_ttl")
|
||||||
|
assert.Equal(t, 0, config.UnloadAfter)
|
||||||
|
config.UnloadAfter = 1 // second
|
||||||
|
assert.Equal(t, 1, config.UnloadAfter)
|
||||||
|
|
||||||
|
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
t.Logf("Waiting before sending request %d", i)
|
||||||
|
time.Sleep(1500 * time.Millisecond)
|
||||||
|
|
||||||
|
expected := fmt.Sprintf("echo=test_%d", i)
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// issue #19
|
// issue #19
|
||||||
|
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||||
|
// requests. All HTTP requests in this test should complete successfully.
|
||||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping slow test")
|
t.Skip("skipping slow test")
|
||||||
@@ -119,7 +189,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
|
|
||||||
expectedMessage := "12345"
|
expectedMessage := "12345"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
process := NewProcess("t", 10, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
results := map[string]string{
|
results := map[string]string{
|
||||||
@@ -135,8 +205,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(key string) {
|
go func(key string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
// send a request that should take 5 * 200ms (1 second) to complete
|
// send a request where simple-responder is will wait 300ms before responding
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
// this will simulate an in-progress request.
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
@@ -152,9 +223,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
}(key)
|
}(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the requests in the middle
|
// Stop the process while requests are still being processed
|
||||||
go func() {
|
go func() {
|
||||||
<-time.After(500 * time.Millisecond)
|
<-time.After(150 * time.Millisecond)
|
||||||
process.Stop()
|
process.Stop()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -164,3 +235,108 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
assert.Equal(t, key, result)
|
assert.Equal(t, key, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_SwapState(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
currentState ProcessState
|
||||||
|
expectedState ProcessState
|
||||||
|
newState ProcessState
|
||||||
|
expectedError error
|
||||||
|
expectedResult ProcessState
|
||||||
|
}{
|
||||||
|
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||||
|
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||||
|
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
||||||
|
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||||
|
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||||
|
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||||
|
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||||
|
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||||
|
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
||||||
|
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||||
|
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
||||||
|
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||||
|
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
||||||
|
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
||||||
|
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||||
|
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||||
|
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
|
||||||
|
p.state = test.currentState
|
||||||
|
|
||||||
|
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||||
|
if err != nil && test.expectedError == nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
} else if err == nil && test.expectedError != nil {
|
||||||
|
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||||
|
} else if err != nil && test.expectedError != nil {
|
||||||
|
if err.Error() != test.expectedError.Error() {
|
||||||
|
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultState != test.expectedResult {
|
||||||
|
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping long shutdown test")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := "testing91931"
|
||||||
|
|
||||||
|
// make a config where the healthcheck will always fail because port is wrong
|
||||||
|
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||||
|
config.Proxy = "http://localhost:9998/test"
|
||||||
|
|
||||||
|
healthCheckTTLSeconds := 30
|
||||||
|
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
|
||||||
|
|
||||||
|
// make it a lot faster
|
||||||
|
process.healthCheckLoopInterval = time.Second
|
||||||
|
|
||||||
|
// start a goroutine to simulate a shutdown
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-time.After(time.Millisecond * 500)
|
||||||
|
process.Shutdown()
|
||||||
|
}()
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
// start the process, this is a blocking call
|
||||||
|
err := process.start()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||||
|
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping Exit Interrupts Health Check test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// should run and exit but interrupt the long checkHealthTimeout
|
||||||
|
checkHealthTimeout := 5
|
||||||
|
config := ModelConfig{
|
||||||
|
Cmd: "sleep 1",
|
||||||
|
Proxy: "http://127.0.0.1:9913",
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||||
|
process.healthCheckLoopInterval = time.Second // make it faster
|
||||||
|
err := process.start()
|
||||||
|
assert.Equal(t, "upstream command exited prematurely with no error", err.Error())
|
||||||
|
assert.Equal(t, process.CurrentState(), StateFailed)
|
||||||
|
}
|
||||||
|
|||||||
+360
-45
@@ -2,11 +2,12 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"embed"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -14,38 +15,119 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PROFILE_SPLIT_CHAR = ":"
|
PROFILE_SPLIT_CHAR = ":"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed html/favicon.ico
|
|
||||||
var faviconData []byte
|
|
||||||
|
|
||||||
//go:embed html/logs.html
|
|
||||||
var logsHTML []byte
|
|
||||||
|
|
||||||
// make sure embed is kept there by the IDE auto-package importer
|
|
||||||
var _ = embed.FS{}
|
|
||||||
|
|
||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config *Config
|
config *Config
|
||||||
currentProcesses map[string]*Process
|
currentProcesses map[string]*Process
|
||||||
logMonitor *LogMonitor
|
|
||||||
ginEngine *gin.Engine
|
ginEngine *gin.Engine
|
||||||
|
|
||||||
|
// logging
|
||||||
|
proxyLogger *LogMonitor
|
||||||
|
upstreamLogger *LogMonitor
|
||||||
|
muxLogger *LogMonitor
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config *Config) *ProxyManager {
|
func New(config *Config) *ProxyManager {
|
||||||
|
// set up loggers
|
||||||
|
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||||
|
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||||
|
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
||||||
|
|
||||||
|
if config.LogRequests {
|
||||||
|
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
||||||
|
case "debug":
|
||||||
|
proxyLogger.SetLogLevel(LevelDebug)
|
||||||
|
upstreamLogger.SetLogLevel(LevelDebug)
|
||||||
|
case "info":
|
||||||
|
proxyLogger.SetLogLevel(LevelInfo)
|
||||||
|
upstreamLogger.SetLogLevel(LevelInfo)
|
||||||
|
case "warn":
|
||||||
|
proxyLogger.SetLogLevel(LevelWarn)
|
||||||
|
upstreamLogger.SetLogLevel(LevelWarn)
|
||||||
|
case "error":
|
||||||
|
proxyLogger.SetLogLevel(LevelError)
|
||||||
|
upstreamLogger.SetLogLevel(LevelError)
|
||||||
|
default:
|
||||||
|
proxyLogger.SetLogLevel(LevelInfo)
|
||||||
|
upstreamLogger.SetLogLevel(LevelInfo)
|
||||||
|
}
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: config,
|
||||||
currentProcesses: make(map[string]*Process),
|
currentProcesses: make(map[string]*Process),
|
||||||
logMonitor: NewLogMonitor(),
|
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
|
|
||||||
|
proxyLogger: proxyLogger,
|
||||||
|
muxLogger: stdoutLogger,
|
||||||
|
upstreamLogger: upstreamLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
|
// Start timer
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// capture these because /upstream/:model rewrites them in c.Next()
|
||||||
|
clientIP := c.ClientIP()
|
||||||
|
method := c.Request.Method
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
|
// Process request
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
// Stop timer
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
statusCode := c.Writer.Status()
|
||||||
|
bodySize := c.Writer.Size()
|
||||||
|
|
||||||
|
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
||||||
|
clientIP,
|
||||||
|
method,
|
||||||
|
path,
|
||||||
|
c.Request.Proto,
|
||||||
|
statusCode,
|
||||||
|
bodySize,
|
||||||
|
c.Request.UserAgent(),
|
||||||
|
duration,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// see: issue: #81, #77 and #42 for CORS issues
|
||||||
|
// respond with permissive OPTIONS for any endpoint
|
||||||
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
|
if c.Request.Method == "OPTIONS" {
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||||
|
|
||||||
|
// allow whatever the client requested by default
|
||||||
|
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||||
|
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
|
||||||
|
c.Header("Access-Control-Allow-Headers", sanitized)
|
||||||
|
} else {
|
||||||
|
c.Header(
|
||||||
|
"Access-Control-Allow-Headers",
|
||||||
|
"Content-Type, Authorization, Accept, X-Requested-With",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
c.Header("Access-Control-Max-Age", "86400")
|
||||||
|
c.AbortWithStatus(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
@@ -55,18 +137,49 @@ func New(config *Config) *ProxyManager {
|
|||||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// Support audio/speech endpoint
|
||||||
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||||
|
|
||||||
// in proxymanager_loghandlers.go
|
// in proxymanager_loghandlers.go
|
||||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||||
|
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||||
|
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
||||||
|
|
||||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||||
|
// Set the Content-Type header to text/html
|
||||||
|
c.Header("Content-Type", "text/html")
|
||||||
|
|
||||||
|
// Write the embedded HTML content to the response
|
||||||
|
htmlData, err := getHTMLFile("index.html")
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = c.Writer.Write(htmlData)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||||
c.Data(http.StatusOK, "image/x-icon", faviconData)
|
if data, err := getHTMLFile("favicon.ico"); err == nil {
|
||||||
|
c.Data(http.StatusOK, "image/x-icon", data)
|
||||||
|
} else {
|
||||||
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Disable console color for testing
|
// Disable console color for testing
|
||||||
@@ -96,13 +209,38 @@ func (pm *ProxyManager) stopProcesses() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stop Processes in parallel
|
||||||
|
var wg sync.WaitGroup
|
||||||
for _, process := range pm.currentProcesses {
|
for _, process := range pm.currentProcesses {
|
||||||
process.Stop()
|
wg.Add(1)
|
||||||
|
go func(process *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
process.Stop()
|
||||||
|
}(process)
|
||||||
}
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
pm.currentProcesses = make(map[string]*Process)
|
pm.currentProcesses = make(map[string]*Process)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shutdown is called to shutdown all upstream processes
|
||||||
|
// when llama-swap is shutting down.
|
||||||
|
func (pm *ProxyManager) Shutdown() {
|
||||||
|
pm.Lock()
|
||||||
|
defer pm.Unlock()
|
||||||
|
|
||||||
|
// shutdown process in parallel
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range pm.currentProcesses {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(process *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
process.Shutdown()
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := []interface{}{}
|
data := []interface{}{}
|
||||||
for id, modelConfig := range pm.config.Models {
|
for id, modelConfig := range pm.config.Models {
|
||||||
@@ -127,7 +265,7 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
|
|
||||||
// Encode the data as JSON and write it to the response writer
|
// Encode the data as JSON and write it to the response writer
|
||||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,11 +275,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||||
profileName, modelName := "", requestedModel
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
|
||||||
profileName = requestedModel[:idx]
|
|
||||||
modelName = requestedModel[idx+1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if profileName != "" {
|
if profileName != "" {
|
||||||
if _, found := pm.config.Profiles[profileName]; !found {
|
if _, found := pm.config.Profiles[profileName]; !found {
|
||||||
@@ -155,23 +289,39 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check if model is part of the profile
|
||||||
|
if profileName != "" {
|
||||||
|
found := false
|
||||||
|
for _, item := range pm.config.Profiles[profileName] {
|
||||||
|
if item == realModelName {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// exit early when already running, otherwise stop everything and swap
|
// exit early when already running, otherwise stop everything and swap
|
||||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
||||||
|
|
||||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||||
|
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
|
||||||
return process, nil
|
return process, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop all running models
|
// stop all running models
|
||||||
|
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
|
||||||
pm.stopProcesses()
|
pm.stopProcesses()
|
||||||
|
|
||||||
if profileName == "" {
|
if profileName == "" {
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
||||||
processKey := ProcessKeyName(profileName, modelID)
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
} else {
|
} else {
|
||||||
@@ -182,7 +332,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
|
||||||
processKey := ProcessKeyName(profileName, modelID)
|
processKey := ProcessKeyName(profileName, modelID)
|
||||||
pm.currentProcesses[processKey] = process
|
pm.currentProcesses[processKey] = process
|
||||||
}
|
}
|
||||||
@@ -197,12 +347,12 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
requestedModel := c.Param("model_id")
|
requestedModel := c.Param("model_id")
|
||||||
|
|
||||||
if requestedModel == "" {
|
if requestedModel == "" {
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("model id required in path"))
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||||
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
} else {
|
} else {
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
c.Request.URL.Path = c.Param("upstreamPath")
|
c.Request.URL.Path = c.Param("upstreamPath")
|
||||||
@@ -238,34 +388,199 @@ func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
|||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||||
return
|
|
||||||
}
|
|
||||||
var requestBody map[string]interface{}
|
|
||||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid JSON"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
model, ok := requestBody["model"].(string)
|
|
||||||
if !ok {
|
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("missing or invalid 'model' key"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(model); err != nil {
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
if requestedModel == "" {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
|
}
|
||||||
|
|
||||||
|
process, err := pm.swapModel(requestedModel)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
return
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
|
if process.config.UseModelName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
|
if profileName != "" {
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||||
c.Request.Header.Del("transfer-encoding")
|
if err != nil {
|
||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||||
|
return
|
||||||
process.ProxyRequest(c.Writer, c.Request)
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
|
c.Request.Header.Del("transfer-encoding")
|
||||||
|
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
|
||||||
|
process.ProxyRequest(c.Writer, c.Request)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||||
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
|
// Create a new buffer for the reconstructed request
|
||||||
|
var requestBuffer bytes.Buffer
|
||||||
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||||
|
|
||||||
|
// Parse multipart form
|
||||||
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get model parameter from the form
|
||||||
|
requestedModel := c.Request.FormValue("model")
|
||||||
|
if requestedModel == "" {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap to the requested model
|
||||||
|
process, err := pm.swapModel(requestedModel)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get profile name and model name from the requested model
|
||||||
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
|
|
||||||
|
// Copy all form values
|
||||||
|
for key, values := range c.Request.MultipartForm.Value {
|
||||||
|
for _, value := range values {
|
||||||
|
fieldValue := value
|
||||||
|
// If this is the model field and we have a profile, use just the model name
|
||||||
|
if key == "model" {
|
||||||
|
if process.config.UseModelName != "" {
|
||||||
|
fieldValue = process.config.UseModelName
|
||||||
|
} else if profileName != "" {
|
||||||
|
fieldValue = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
field, err := multipartWriter.CreateFormField(key)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy all files from the original request
|
||||||
|
for key, fileHeaders := range c.Request.MultipartForm.File {
|
||||||
|
for _, fileHeader := range fileHeaders {
|
||||||
|
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := fileHeader.Open()
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = io.Copy(formFile, file); err != nil {
|
||||||
|
file.Close()
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the multipart writer to finalize the form
|
||||||
|
if err := multipartWriter.Close(); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new request with the reconstructed form data
|
||||||
|
modifiedReq, err := http.NewRequestWithContext(
|
||||||
|
c.Request.Context(),
|
||||||
|
c.Request.Method,
|
||||||
|
c.Request.URL.String(),
|
||||||
|
&requestBuffer,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the headers from the original request
|
||||||
|
modifiedReq.Header = c.Request.Header.Clone()
|
||||||
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||||
|
|
||||||
|
// Use the modified request for proxying
|
||||||
|
process.ProxyRequest(c.Writer, modifiedReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||||
|
acceptHeader := c.GetHeader("Accept")
|
||||||
|
|
||||||
|
if strings.Contains(acceptHeader, "application/json") {
|
||||||
|
c.JSON(statusCode, gin.H{"error": message})
|
||||||
|
} else {
|
||||||
|
c.String(statusCode, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
|
pm.StopProcesses()
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||||
|
context.Header("Content-Type", "application/json")
|
||||||
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||||
|
|
||||||
|
for _, process := range pm.currentProcesses {
|
||||||
|
|
||||||
|
// Append the process ID and State (multiple entries if profiles are being used).
|
||||||
|
runningProcesses = append(runningProcesses, gin.H{
|
||||||
|
"model": process.ID,
|
||||||
|
"state": process.state,
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put the results under the `running` key.
|
||||||
|
response := gin.H{
|
||||||
|
"running": runningProcesses,
|
||||||
|
}
|
||||||
|
|
||||||
|
context.JSON(http.StatusOK, response) // Always return 200 OK
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProcessKeyName(groupName, modelName string) string {
|
func ProcessKeyName(groupName, modelName string) string {
|
||||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitRequestedModel(requestedModel string) (string, string) {
|
||||||
|
profileName, modelName := "", requestedModel
|
||||||
|
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||||
|
profileName = requestedModel[:idx]
|
||||||
|
modelName = requestedModel[idx+1:]
|
||||||
|
}
|
||||||
|
return profileName, modelName
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,21 +9,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||||
|
|
||||||
accept := c.GetHeader("Accept")
|
accept := c.GetHeader("Accept")
|
||||||
if strings.Contains(accept, "text/html") {
|
if strings.Contains(accept, "text/html") {
|
||||||
// Set the Content-Type header to text/html
|
// Set the Content-Type header to text/html
|
||||||
c.Header("Content-Type", "text/html")
|
c.Header("Content-Type", "text/html")
|
||||||
|
|
||||||
// Write the embedded HTML content to the response
|
// Write the embedded HTML content to the response
|
||||||
_, err := c.Writer.Write(logsHTML)
|
logsHTML, err := getHTMLFile("logs.html")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to write response: %v", err))
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = c.Writer.Write(logsHTML)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
history := pm.logMonitor.GetHistory()
|
history := pm.muxLogger.GetHistory()
|
||||||
_, err := c.Writer.Write(history)
|
_, err := c.Writer.Write(history)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, err)
|
c.AbortWithError(http.StatusInternalServerError, err)
|
||||||
@@ -37,13 +41,19 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
c.Header("Transfer-Encoding", "chunked")
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
ch := pm.logMonitor.Subscribe()
|
logMonitorId := c.Param("logMonitorID")
|
||||||
defer pm.logMonitor.Unsubscribe(ch)
|
logger, err := pm.getLogger(logMonitorId)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch := logger.Subscribe()
|
||||||
|
defer logger.Unsubscribe(ch)
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
notify := c.Request.Context().Done()
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("Streaming unsupported"))
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,13 +61,9 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
// Send history first if not skipped
|
// Send history first if not skipped
|
||||||
|
|
||||||
if !skipHistory {
|
if !skipHistory {
|
||||||
history := pm.logMonitor.GetHistory()
|
history := logger.GetHistory()
|
||||||
if len(history) != 0 {
|
if len(history) != 0 {
|
||||||
_, err := c.Writer.Write(history)
|
c.Writer.Write(history)
|
||||||
if err != nil {
|
|
||||||
c.AbortWithError(http.StatusInternalServerError, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -68,7 +74,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
case msg := <-ch:
|
case msg := <-ch:
|
||||||
_, err := c.Writer.Write(msg)
|
_, err := c.Writer.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, err)
|
// just break the loop if we can't write for some reason
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -84,15 +90,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
|||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
ch := pm.logMonitor.Subscribe()
|
logMonitorId := c.Param("logMonitorID")
|
||||||
defer pm.logMonitor.Unsubscribe(ch)
|
logger, err := pm.getLogger(logMonitorId)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch := logger.Subscribe()
|
||||||
|
defer logger.Unsubscribe(ch)
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
notify := c.Request.Context().Done()
|
||||||
|
|
||||||
// Send history first if not skipped
|
// Send history first if not skipped
|
||||||
_, skipHistory := c.GetQuery("no-history")
|
_, skipHistory := c.GetQuery("no-history")
|
||||||
if !skipHistory {
|
if !skipHistory {
|
||||||
history := pm.logMonitor.GetHistory()
|
history := logger.GetHistory()
|
||||||
if len(history) != 0 {
|
if len(history) != 0 {
|
||||||
c.SSEvent("message", string(history))
|
c.SSEvent("message", string(history))
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
@@ -110,3 +122,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||||
|
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||||
|
var logger *LogMonitor
|
||||||
|
|
||||||
|
if logMonitorId == "" {
|
||||||
|
// maintain the default
|
||||||
|
logger = pm.muxLogger
|
||||||
|
} else if logMonitorId == "proxy" {
|
||||||
|
logger = pm.proxyLogger
|
||||||
|
} else if logMonitorId == "upstream" {
|
||||||
|
logger = pm.upstreamLogger
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
||||||
|
}
|
||||||
|
|
||||||
|
return logger, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -20,6 +22,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
|
LogLevel: "error",
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
@@ -60,6 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
Profiles: map[string][]string{
|
Profiles: map[string][]string{
|
||||||
"test": {model1, model2},
|
"test": {model1, model2},
|
||||||
},
|
},
|
||||||
|
LogLevel: "error",
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
@@ -101,6 +105,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
},
|
},
|
||||||
|
LogLevel: "error",
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
@@ -151,6 +156,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
},
|
},
|
||||||
|
LogLevel: "error",
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
@@ -210,3 +216,507 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
// Ensure all expected models were returned
|
// Ensure all expected models were returned
|
||||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ProfileNonMember(t *testing.T) {
|
||||||
|
|
||||||
|
model1 := "path1/model1"
|
||||||
|
model2 := "path2/model2"
|
||||||
|
|
||||||
|
profileMemberName := ProcessKeyName("test", model1)
|
||||||
|
profileNonMemberName := ProcessKeyName("test", model2)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
model1: getTestSimpleResponderConfig("model1"),
|
||||||
|
model2: getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {model1},
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
// actual member of profile
|
||||||
|
{
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "model1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// actual model, but non-member will 404
|
||||||
|
{
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_Shutdown(t *testing.T) {
|
||||||
|
// make broken model configurations
|
||||||
|
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||||
|
model1Config.Proxy = "http://localhost:10001/"
|
||||||
|
|
||||||
|
model2Config := getTestSimpleResponderConfigPort("model2", 9992)
|
||||||
|
model2Config.Proxy = "http://localhost:10002/"
|
||||||
|
|
||||||
|
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||||
|
model3Config.Proxy = "http://localhost:10003/"
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2", "model3"},
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": model1Config,
|
||||||
|
"model2": model2Config,
|
||||||
|
"model3": model3Config,
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
|
||||||
|
// Start all the processes
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(modelName string) {
|
||||||
|
defer wg.Done()
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// send a request to trigger the proxy to load
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
|
||||||
|
//fmt.Println(w.Code, w.Body.String())
|
||||||
|
}(modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-time.After(time.Second)
|
||||||
|
proxy.Shutdown()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_Unload(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
proc, err := proxy.swapModel("model1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, proc)
|
||||||
|
|
||||||
|
assert.Len(t, proxy.currentProcesses, 1)
|
||||||
|
req := httptest.NewRequest("GET", "/unload", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
|
assert.Len(t, proxy.currentProcesses, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue 62, strip profile slug from model name
|
||||||
|
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
||||||
|
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||||
|
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||||
|
|
||||||
|
// Shared configuration
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define a helper struct to parse the JSON response.
|
||||||
|
type RunningResponse struct {
|
||||||
|
Running []struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
State string `json:"state"`
|
||||||
|
} `json:"running"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create proxy once for all tests
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
t.Run("no models loaded", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/running", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var response RunningResponse
|
||||||
|
|
||||||
|
// Check if this is a valid JSON object.
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
|
||||||
|
// We should have an empty running array here.
|
||||||
|
assert.Empty(t, response.Running, "expected no running models")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single model loaded", func(t *testing.T) {
|
||||||
|
// Load just a model.
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
// Simulate browser call for the `/running` endpoint.
|
||||||
|
req = httptest.NewRequest("GET", "/running", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
|
||||||
|
var response RunningResponse
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
|
||||||
|
// Check if we have a single array element.
|
||||||
|
assert.Len(t, response.Running, 1)
|
||||||
|
|
||||||
|
// Is this the right model?
|
||||||
|
assert.Equal(t, "model1", response.Running[0].Model)
|
||||||
|
|
||||||
|
// Is the model loaded?
|
||||||
|
assert.Equal(t, "ready", response.Running[0].State)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple models via profile", func(t *testing.T) {
|
||||||
|
// Load more than one model.
|
||||||
|
for _, model := range []string{"model1", "model2"} {
|
||||||
|
profileModel := ProcessKeyName("test", model)
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate the browser call.
|
||||||
|
req := httptest.NewRequest("GET", "/running", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
|
||||||
|
var response RunningResponse
|
||||||
|
|
||||||
|
// The JSON response must be valid.
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
|
||||||
|
// The response should contain 2 models.
|
||||||
|
assert.Len(t, response.Running, 2)
|
||||||
|
|
||||||
|
expectedModels := map[string]struct{}{
|
||||||
|
"model1": {},
|
||||||
|
"model2": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate through the models and check their states as well.
|
||||||
|
for _, entry := range response.Running {
|
||||||
|
_, exists := expectedModels[entry.Model]
|
||||||
|
assert.True(t, exists, "unexpected model %s", entry.Model)
|
||||||
|
assert.Equal(t, "ready", entry.State)
|
||||||
|
delete(expectedModels, entry.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since we deleted each model while testing for its validity we should have no more models in the response.
|
||||||
|
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"TheExpectedModel"},
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
modelInput string
|
||||||
|
expectModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "With Profile Prefix",
|
||||||
|
modelInput: "test:TheExpectedModel",
|
||||||
|
expectModel: "TheExpectedModel", // Profile prefix should be stripped
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Without Profile Prefix",
|
||||||
|
modelInput: "TheExpectedModel",
|
||||||
|
expectModel: "TheExpectedModel", // Should remain the same
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create a buffer with multipart form data
|
||||||
|
var b bytes.Buffer
|
||||||
|
w := multipart.NewWriter(&b)
|
||||||
|
|
||||||
|
// Add the model field
|
||||||
|
fw, err := w.CreateFormField("model")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = fw.Write([]byte(tc.modelInput))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Add a file field
|
||||||
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// Generate random content length between 10 and 20
|
||||||
|
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||||
|
content := make([]byte, contentLength)
|
||||||
|
_, err = fw.Write(content)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
// Create the request with the multipart form data
|
||||||
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(rec, req)
|
||||||
|
|
||||||
|
// Verify the response
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
var response map[string]string
|
||||||
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expectModel, response["model"])
|
||||||
|
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_SplitRequestedModel(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
requestedModel string
|
||||||
|
expectedProfile string
|
||||||
|
expectedModel string
|
||||||
|
}{
|
||||||
|
{"no profile", "gpt-4", "", "gpt-4"},
|
||||||
|
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
|
||||||
|
{"only profile", "profile1:", "profile1", ""},
|
||||||
|
{"empty model", ":gpt-4", "", "gpt-4"},
|
||||||
|
{"empty profile", ":", "", ""},
|
||||||
|
{"no split char", "gpt-4", "", "gpt-4"},
|
||||||
|
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
profileName, modelName := splitRequestedModel(tt.requestedModel)
|
||||||
|
if profileName != tt.expectedProfile {
|
||||||
|
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||||
|
}
|
||||||
|
if modelName != tt.expectedModel {
|
||||||
|
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||||
|
func TestProxyManager_UseModelName(t *testing.T) {
|
||||||
|
|
||||||
|
upstreamModelName := "upstreamModel"
|
||||||
|
|
||||||
|
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||||
|
modelConfig.UseModelName = upstreamModelName
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1"},
|
||||||
|
},
|
||||||
|
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": modelConfig,
|
||||||
|
},
|
||||||
|
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
description string
|
||||||
|
requestedModel string
|
||||||
|
}{
|
||||||
|
{"useModelName over rides requested model", "model1"},
|
||||||
|
{"useModelName over rides requested profile:model", "test:model1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), upstreamModelName)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
|
||||||
|
// Create a buffer with multipart form data
|
||||||
|
var b bytes.Buffer
|
||||||
|
w := multipart.NewWriter(&b)
|
||||||
|
|
||||||
|
// Add the model field
|
||||||
|
fw, err := w.CreateFormField("model")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = fw.Write([]byte(tt.requestedModel))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Add a file field
|
||||||
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = fw.Write([]byte("test"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
// Create the request with the multipart form data
|
||||||
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(rec, req)
|
||||||
|
|
||||||
|
// Verify the response
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
var response map[string]string
|
||||||
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, upstreamModelName, response["model"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
requestHeaders map[string]string
|
||||||
|
expectedStatus int
|
||||||
|
expectedHeaders map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "OPTIONS with no headers",
|
||||||
|
method: "OPTIONS",
|
||||||
|
expectedStatus: http.StatusNoContent,
|
||||||
|
expectedHeaders: map[string]string{
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||||
|
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OPTIONS with specific headers",
|
||||||
|
method: "OPTIONS",
|
||||||
|
requestHeaders: map[string]string{
|
||||||
|
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusNoContent,
|
||||||
|
expectedHeaders: map[string]string{
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
|
||||||
|
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-OPTIONS request",
|
||||||
|
method: "GET",
|
||||||
|
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
|
||||||
|
for k, v := range tt.requestHeaders {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ginEngine.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||||
|
|
||||||
|
for header, expectedValue := range tt.expectedHeaders {
|
||||||
|
assert.Equal(t, expectedValue, w.Header().Get(header))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isTokenChar(r rune) bool {
|
||||||
|
switch {
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
case r >= 'A' && r <= 'Z':
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||||
|
parts := strings.Split(headerValues, ",")
|
||||||
|
valid := make([]string, 0, len(parts))
|
||||||
|
|
||||||
|
for _, p := range parts {
|
||||||
|
v := strings.TrimSpace(p)
|
||||||
|
if v == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
validPart := true
|
||||||
|
for _, c := range v {
|
||||||
|
if !isTokenChar(c) {
|
||||||
|
validPart = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if validPart {
|
||||||
|
valid = append(valid, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(valid, ", ")
|
||||||
|
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace only",
|
||||||
|
input: " ",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single valid value",
|
||||||
|
input: "content-type",
|
||||||
|
expected: "content-type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid values",
|
||||||
|
input: "content-type, authorization, x-requested-with",
|
||||||
|
expected: "content-type, authorization, x-requested-with",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "values with extra spaces",
|
||||||
|
input: " content-type , authorization ",
|
||||||
|
expected: "content-type, authorization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "values with tabs",
|
||||||
|
input: "content-type,\tauthorization",
|
||||||
|
expected: "content-type, authorization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "values with invalid characters",
|
||||||
|
input: "content-type, auth\n, x-requested-with\r",
|
||||||
|
expected: "content-type, auth, x-requested-with",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty values in list",
|
||||||
|
input: "content-type,,authorization",
|
||||||
|
expected: "content-type, authorization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "leading and trailing commas",
|
||||||
|
input: ",content-type,authorization,",
|
||||||
|
expected: "content-type, authorization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed valid and invalid values",
|
||||||
|
input: "content-type, \x00invalid, x-requested-with",
|
||||||
|
expected: "content-type, x-requested-with",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case values",
|
||||||
|
input: "Content-Type, my-Valid-Header, Another-hEader",
|
||||||
|
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
||||||
|
tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user