Compare commits

...

58 Commits

Author SHA1 Message Date
Benson Wong 593604dfdc Show proxy and upstream logs in separate columns in logs UI 2025-04-05 10:36:54 -07:00
Benson Wong b8f888f864 Logging Improvements (#88)
This change revamps the internal logging architecture to be more flexible and descriptive. Previously all logs from both llama-swap and upstream services were mixed together. This makes it harder to troubleshoot and identify problems. This PR adds these new endpoints: 

- `/logs/stream/proxy` - just llama-swap's logs
- `/logs/stream/upstream` - stdout output from the upstream server
2025-04-04 21:01:33 -07:00
Benson Wong 192b2ae621 Remove no longer needed test 2025-04-04 14:46:01 -07:00
Benson Wong b7f8cb5094 Limit Access-Control-Allow-Origin to OPTIONS preflight requests #85 2025-04-04 14:44:35 -07:00
Benson Wong a23da6eb57 Sanitize CORS headers (#85)
Add sanitation step for `Access-Control-Allow-Headers` when echoing back user supplied headers
2025-04-01 08:43:53 -07:00
Grigorii Khvatskii 4c3aa40564 add graceful process termination on windows (#82) 2025-03-25 15:26:33 -07:00
Benson Wong 84e2c07a7e Refactor wildcard out of CORS headers (#81)
Changes to CORS functionality: 

- `Access-Control-Allow-Origin: *` is set for all requests 
- for pre-flight OPTIONS requests
  - specify methods: `Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS`
  - if the client sent `Access-Control-Request-Headers` then echo back the same value in `Access-Control-Allow-Headers`. If no `Access-Control-Request-Headers` were sent, then send back a default set
  - set `Access-Control-Max-Age: 86400` to that may improve performance 
- Add CORS tests to the proxy-manager
2025-03-25 15:24:43 -07:00
Benson Wong 680af28bcc Allow very permissive CORS headers (#77) 2025-03-20 15:50:21 -07:00
Benson Wong d94db42ffe fix bug checking incorrect error 2025-03-20 15:49:36 -07:00
Benson Wong 93cd83c55c add override for windows (#76) 2025-03-20 13:23:04 -07:00
Benson Wong 5565fca3ac add some badges to README 2025-03-19 11:25:06 -07:00
Benson Wong d625ab8d92 Refactor process state management (#70) (#73)
* add isValidStateTransition helper function
* Replace Process.setState() with Process.swapState()
* Refactor locking logic in Process
2025-03-15 17:14:03 -07:00
Benson Wong a3f82c140b tidy up config examples in README 2025-03-15 10:36:45 -07:00
Benson Wong 5c97299e7b Add support for sending a custom model name to upstream (#69) (#71)
* add test for splitRequestedModel()
* Add `useModelName` parameter to model configuration
* add docs to README
2025-03-14 21:07:52 -07:00
Benson Wong 671c1a5a7b update deps 2025-03-13 14:00:15 -07:00
Benson Wong 52c0196e0f clean up feature list in readme 2025-03-13 13:55:20 -07:00
Benson Wong 3201a68a04 Add /v1/audio/transcriptions support (#41)
* add support for /v1/audio/transcriptions
2025-03-13 13:49:39 -07:00
Florin-Gabriel Dumitru 3ac94ad20e Adds an endpoint '/running' (#61)
* Adds an endpoint '/running' that returns either an empty JSON object if no model has been loaded so far, or the last model loaded (model key) and it's current state (state key). Possible state values are: stopped, starting, ready and stopping.

* Improves the `/running` endpoint by allowing multiple entries under the `running` key within the JSON response.
Refactors the `/running` method name (listRunningProcessesHandler).
Removes the unlisted filter implementation.

* Adds tests for:
- no model loaded
- one model loaded
- multiple models loaded

* Adds simple comments.

* Simplified code structure as per 250313 comments on PR #65.

---------

Co-authored-by: FGDumitru|B <xelotx@gmail.com>
2025-03-13 13:42:59 -07:00
Benson Wong 60355bf74a fix some potentially confusing Process.start() comment 2025-03-11 11:00:45 -07:00
Benson Wong 9b2ed244e2 Improve Continuous integration and fix concurrency bugs (#66)
- improvements to the continuous GH actions
- fix edge case concurrency bugs with Process.start() and state transitions discovered setting up CI.
2025-03-11 10:39:14 -07:00
Benson Wong eeb72297f7 add first version of CI for go 2025-03-11 08:45:56 -07:00
Benson Wong eabfe70cc6 add GH action to close inactive issues 2025-03-09 19:51:48 -07:00
Benson Wong 29cd98878d better container build logic when upstream containers do not exist 2025-03-09 13:02:06 -07:00
Benson Wong b3d331da0d Properly strip profile name slug from models fixes (#62)
The profile slug in a model name, `profile:model`, is specific to
llama-swap. This strips `profile:` out of the model name request so
upstreams that expect just `model` work and do not require knowing about
the profile slug.
2025-03-09 12:41:52 -07:00
Benson Wong 62275e078d add examples to restart on config change #59 2025-03-06 10:50:29 -08:00
Benson Wong 88916059e1 add /unload to docs 2025-03-03 10:44:16 -08:00
Benson Wong 082d5d0fc5 Add /unload endpoint (#58) to unload all currently running models 2025-03-03 10:33:36 -08:00
Benson Wong 53338938bd increase health check to a minimum of 5 seconds 2025-03-03 10:04:08 -08:00
Benson Wong af653347ae Update README.md w/ starhistory graph 2025-02-27 16:43:34 -08:00
Benson Wong 1e25b44a06 add workflow_dispatch to release action 2025-02-18 17:27:43 -08:00
Benson Wong 0815bb4cc3 Add windows to goreleaser #54 2025-02-18 17:26:43 -08:00
daschiller 7187cfe52e add Windows build support to Makefile (#54) 2025-02-18 17:24:31 -08:00
Benson Wong 24089d2d9c remove "no musa container" note from README 2025-02-18 16:38:48 -08:00
Benson Wong ebabe55ff3 Delete untagged packages after build and push (#55) 2025-02-18 10:32:32 -08:00
Benson Wong 41a338297c deletion of untagged containers happen after build-and-push 2025-02-18 10:11:59 -08:00
Benson Wong 7e3353efeb add action step to remove untagged containers 2025-02-18 10:08:41 -08:00
Benson Wong 4ed58fb173 update container build action 2025-02-18 09:59:06 -08:00
Benson Wong f5a2be698d revert package src until new ggml-org has them 2025-02-15 18:23:58 -08:00
Benson Wong f5e6ec3b7a fix package src in containerfile 2025-02-15 18:20:35 -08:00
Benson Wong 3f462da146 switch package source from ggerganov to ggml-org 2025-02-15 18:18:49 -08:00
Benson Wong 48bd766536 Update README.md 2025-02-14 22:05:52 -08:00
Benson Wong 8d319da4dd improve README organization (i think...) 2025-02-14 15:59:12 -08:00
Benson Wong be7c502448 improve docs 2025-02-14 15:47:31 -08:00
Benson Wong 92336f00bf more container build fixes 2025-02-14 15:34:38 -08:00
Benson Wong ed2a50d9a6 fix bug in build-container.sh 2025-02-14 15:27:56 -08:00
Benson Wong 0acfdb9f78 update workflow to build cpu and disable musa 2025-02-14 15:26:59 -08:00
Benson Wong 96a8ea0241 add cpu docker container build 2025-02-14 15:25:45 -08:00
Benson Wong f20f2c9b7a add docs and container build improvements #43 2025-02-14 12:20:07 -08:00
Benson Wong 7a97c38828 enable parallel container built #46 2025-02-14 11:04:33 -08:00
Benson Wong 4885132565 more permissions futzing 2025-02-14 11:02:15 -08:00
Benson Wong 8b46a0b7f1 grant package:write to container workflow #46 2025-02-14 10:55:30 -08:00
Benson Wong 1b6736ec6f rename workflow for containers 2025-02-14 10:50:15 -08:00
Benson Wong ddc1ce031e fix container file name #46 2025-02-14 10:49:44 -08:00
Benson Wong 11d024bbaa just build cuda while debugging 2025-02-14 10:48:06 -08:00
Benson Wong 43e23c16dc add check for GITHUB_TOKEN #46 2025-02-14 10:47:25 -08:00
Benson Wong f9c8e763ba add execute bit on build-container.sh 2025-02-14 10:44:53 -08:00
Benson Wong d7e1bb9f7c add GITHUB_TOKEN to container build env 2025-02-14 10:43:44 -08:00
Benson Wong ab93460a8b first container code (#52) 2025-02-14 10:39:25 -08:00
27 changed files with 1791 additions and 341 deletions
+23
View File
@@ -0,0 +1,23 @@
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
name: Close inactive issues
on:
schedule:
- cron: "32 1 * * *"
jobs:
close-issues:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v9
with:
days-before-issue-stale: 30
days-before-issue-close: 14
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}
+46
View File
@@ -0,0 +1,46 @@
name: Build Containers
on:
# time has no specific meaning, trying to time it after
# the llama.cpp daily packages are published
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
schedule:
- cron: "37 5 * * *"
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
build-and-push:
runs-on: ubuntu-latest
strategy:
matrix:
platform: [intel, cuda, vulkan, cpu, musa]
fail-fast: false
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Run build-container
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: ./docker/build-container.sh ${{ matrix.platform }} true
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
# see: https://github.com/actions/delete-package-versions/issues/74
delete-untagged-containers:
needs: build-and-push
runs-on: ubuntu-latest
steps:
- uses: actions/delete-package-versions@v5
with:
package-name: 'llama-swap'
package-type: 'container'
delete-only-untagged-versions: 'true'
+32
View File
@@ -0,0 +1,32 @@
# This workflow will build a golang project
name: CI
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
# necessary for testing proxy/Process swapping
- name: Create simple-responder
run: make simple-responder
- name: Test all
run: make test-all
+3
View File
@@ -5,6 +5,9 @@ on:
tags:
- '*'
# Allows manual triggering of the workflow
workflow_dispatch:
permissions:
contents: write
+16 -1
View File
@@ -7,9 +7,24 @@ builds:
- linux
- darwin
- freebsd
- windows
goarch:
- amd64
- arm64
ignore:
- goos: freebsd
goarch: arm64
goarch: arm64
- goos: windows
goarch: arm64
# use zip format for windows
archives:
- id: default
format: tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
builds_info:
group: root
owner: root
format_overrides:
- goos: windows
format: zip
+6 -1
View File
@@ -35,6 +35,11 @@ linux:
@echo "Building Linux binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
# Build Windows binary
windows:
@echo "Building Windows binary..."
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
# for testing proxy.Process
simple-responder:
@echo "Building simple responder"
@@ -60,4 +65,4 @@ release:
git tag "$$new_tag";
# Phony targets
.PHONY: all clean mac linux
.PHONY: all clean mac linux windows simple-responder
+141 -42
View File
@@ -1,66 +1,94 @@
![llama-swap header image](header.jpeg)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file).
Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or build it yourself from source with `make clean all`.
## How does it work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If a server is already running it will stop it and start the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## Do I need to use llama.cpp's server (llama-server)?
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
## Features:
- ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file
- ✅ On-demand model switching
- ✅ Full control over server settings per model
- ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/embeddings`
- `v1/rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- ✅ Multiple GPU support
-Docker and Podman support
- ✅ Run multiple models at once with `profiles`
- ✅ Remote log monitoring at `/log`
- ✅ Automatic unloading of models from GPUs after timeout
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
-llama-swap custom API endpoints
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
- ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- ✅ Docker and Podman support
- ✅ Full control over server settings per model
## How does llama-swap work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
## config.yaml
llama-swap's configuration is purposefully simple.
```yaml
models:
"qwen2.5":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999
"smollm2":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
```
<details>
<summary>But also very powerful ...</summary>
```yaml
# Seconds to wait for llama.cpp to load and be ready to serve requests
# Default (and minimum) is 15 seconds
healthCheckTimeout: 60
# Write HTTP logs (useful for troubleshooting), defaults to false
logRequests: true
# Valid log levels: debug, info (default), warn, error
logLevel: info
# define valid model values and the upstream server start
models:
"llama":
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# where to reach the server started by cmd, make sure the ports match
proxy: http://127.0.0.1:8999
# aliases names to use this model for
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# check this path for an HTTP 200 OK before serving requests
# default: /health to match llama.cpp
@@ -73,22 +101,15 @@ models:
# default: 0 = never unload model
ttl: 60
"qwen":
# environment variables to pass to the command
env:
- "CUDA_VISIBLE_DEVICES=0"
# multiline for readability
cmd: >
llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999
# `useModelName` overrides the model name in the request
# and sends a specific name to the upstream server
useModelName: "qwen:qwq"
# unlisted models do not show up in /v1/models or /upstream lists
# but they can still be requested as normal
"qwen-unlisted":
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
unlisted: true
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker Support (v26.1.4+ required!)
"docker-llama":
@@ -99,7 +120,7 @@ models:
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# profiles make it easy to managing multi model (and gpu) configurations.
# profiles eliminates swapping by running multiple models at the same time
#
# Tips:
# - each model must be listening on a unique address and port
@@ -107,21 +128,78 @@ models:
# - the profile will load and unload all models in the profile at the same time
profiles:
coding:
- "qwen"
- "llama"
- "qwen-unlisted"
```
### Advanced Examples
### Use Case Examples
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
### Installation
## Configuration
llama-s
</details>
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker is the quickest way to try out llama-swap:
```
# use CPU inference
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
# qwen2.5 0.5B
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
# SmolLM2 135M
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
```
<details>
<summary>Docker images are nightly ...</summary>
They include:
- `ghcr.io/mostlygeek/llama-swap:cpu`
- `ghcr.io/mostlygeek/llama-swap:cuda`
- `ghcr.io/mostlygeek/llama-swap:intel`
- `ghcr.io/mostlygeek/llama-swap:vulkan`
- ROCm disabled until fixed in llama.cpp container
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
```
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
```
</details>
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
* _Note: Windows currently untested._
1. Run the binary with `llama-swap --config path/to/config.yaml`
### Building from source
@@ -141,9 +219,15 @@ Of course, CLI access is also supported:
# sends up to the last 10KB of logs
curl http://host/logs'
# streams logs
# streams combined logs
curl -Ns 'http://host/logs/stream'
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
@@ -151,11 +235,18 @@ curl -Ns http://host/logs/stream | grep 'eval time'
curl -Ns 'http://host/logs/stream?no-history'
```
## Do I need to use llama.cpp's server (llama-server)?
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
## Systemd Unit Files
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
`/etc/systemd/system/llama-swap.service`
```
[Unit]
Description=llama-swap
@@ -175,3 +266,11 @@ StartLimitInterval=30
[Install]
WantedBy=multi-user.target
```
## Star History
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
</picture>
+3 -3
View File
@@ -1,9 +1,9 @@
# Seconds to wait for llama.cpp to be available to serve requests
# Default (and minimum): 15 seconds
healthCheckTimeout: 15
healthCheckTimeout: 90
# Log HTTP requests helpful for troubleshoot, defaults to False
logRequests: true
# valid log levels: debug, info (default), warn, error
logLevel: info
models:
"llama":
+55
View File
@@ -0,0 +1,55 @@
#!/bin/bash
cd $(dirname "$0")
ARCH=$1
PUSH_IMAGES=${2:-false}
# List of allowed architectures
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
# Check if ARCH is in the allowed list
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
exit 1
fi
# Check if GITHUB_TOKEN is set and not empty
if [[ -z "$GITHUB_TOKEN" ]]; then
echo "Error: GITHUB_TOKEN is not set or is empty."
exit 1
fi
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
if [ "$ARCH" == "cpu" ]; then
# cpu only containers just use the latest available
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
echo "Building ${CONTAINER_LATEST} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_LATEST}
fi
else
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
fi
+17
View File
@@ -0,0 +1,17 @@
healthCheckTimeout: 300
logRequests: true
models:
"qwen2.5":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port 9999
"smollm2":
proxy: "http://127.0.0.1:9999"
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
+16
View File
@@ -0,0 +1,16 @@
ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=89
WORKDIR /app
RUN \
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
COPY config.example.yaml /app/config.yaml
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
@@ -0,0 +1,51 @@
# Restart llama-swap on config change
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
```bash
#!/bin/bash
#
# A simple watch and restart llama-swap when its configuration
# file changes. Useful for trying out configuration changes
# without manually restarting the server each time.
if [ -z "$1" ]; then
echo "Usage: $0 <path to config.yaml>"
exit 1
fi
while true; do
# Start the process again
./llama-swap-linux-amd64 -config $1 -listen :1867 &
PID=$!
echo "Started llama-swap with PID $PID"
# Wait for modifications in the specified directory or file
inotifywait -e modify "$1"
# Check if process exists before sending signal
if kill -0 $PID 2>/dev/null; then
echo "Sending SIGTERM to $PID"
kill -SIGTERM $PID
wait $PID
else
echo "Process $PID no longer exists"
fi
sleep 1
done
```
## Usage and output example
```bash
$ ./watch-and-restart.sh config.yaml
Started llama-swap with PID 495455
Setting up watches.
Watches established.
llama-swap listening on :1867
Sending SIGTERM to 495455
Shutting down llama-swap
Started llama-swap with PID 495486
Setting up watches.
Watches established.
llama-swap listening on :1867
```
+10 -6
View File
@@ -3,7 +3,11 @@ module github.com/mostlygeek/llama-swap
go 1.23.0
require (
github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
gopkg.in/yaml.v3 v3.0.1
)
@@ -15,12 +19,10 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.10.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
@@ -29,12 +31,14 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+18
View File
@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -68,20 +78,28 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+61
View File
@@ -12,12 +12,14 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
func main() {
gin.SetMode(gin.TestMode)
// Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on")
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
// Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with")
@@ -41,11 +43,70 @@ func main() {
c.String(200, *responseMessage)
})
// for issue #62 to check model name strips profile slug
// has to be one of the openAI API endpoints that llama-swap proxies
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
r.POST("/v1/audio/speech", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
if modelName != *expectedModel {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
return
} else {
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
})
// issue #41
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
// Parse the multipart form
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
// Get the model from the form values
model := c.Request.FormValue("model")
if model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
return
}
// Get the file from the form
file, _, err := c.Request.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
defer file.Close()
// Read the file content to get its size
fileBytes, err := io.ReadAll(file)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
return
}
fileSize := len(fileBytes)
// Return a JSON response with the model and transcription text including file size
c.JSON(http.StatusOK, gin.H{
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
"model": model,
})
})
r.GET("/slow-respond", func(c *gin.Context) {
echo := c.Query("echo")
delay := c.Query("delay")
+2
View File
@@ -17,6 +17,7 @@ type ModelConfig struct {
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
@@ -26,6 +27,7 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
Models map[string]ModelConfig `yaml:"models"`
Profiles map[string][]string `yaml:"profiles"`
+188 -75
View File
@@ -12,32 +12,65 @@
flex-direction: column;
font-family: "Courier New", Courier, monospace;
}
#log-controls {
margin: 0.5em;
.log-container {
display: flex;
align-items: center;
justify-content: space-between; /* Spaces out elements evenly */
}
#log-controls input {
flex: 1;
}
#log-controls input:focus {
outline: none; /* Ensures no outline is shown when the input is focused */
}
#log-stream {
flex: 1;
gap: 0.5em;
margin: 0.5em;
min-height: 0;
}
.log-column {
display: flex;
flex-direction: column;
flex: 1;
min-width: 0;
transition: flex 0.3s ease;
}
.log-column.minimized {
flex: 0.1;
max-width: 50px;
border: 1px solid #777;
color: green;
}
.log-controls {
display: grid;
grid-template-columns: 1fr auto;
gap: 0.5em;
margin-bottom: 0.5em;
}
.log-controls input {
width: 100%;
padding: 4px;
}
.log-controls input:focus {
outline: none;
}
.log-stream {
flex: 1;
padding: 1em;
background: #f4f4f4;
overflow-y: auto;
white-space: pre-wrap; /* Ensures line wrapping */
word-wrap: break-word; /* Ensures long words wrap */
white-space: pre-wrap;
word-wrap: break-word;
min-height: 0;
}
.regex-error {
background-color: #ff0000 !important;
}
/* Make headers clickable and show pointer cursor */
h2 {
cursor: pointer;
user-select: none;
margin: 0 0 0.5em 0;
padding: 0.5em;
}
h2:hover {
background-color: rgba(0, 0, 0, 0.05);
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
body {
@@ -45,101 +78,181 @@
color: #fff;
}
#log-stream {
.log-stream {
background: #444;
color: #fff;
}
#log-controls input {
.log-controls input {
background: #555;
color: #fff;
border: 1px solid #777;
}
#log-controls button {
.log-controls button {
background: #555;
color: #fff;
border: 1px solid #777;
}
h2:hover {
background-color: rgba(255, 255, 255, 0.1);
}
}
/* Hide content when minimized */
.log-column.minimized .log-controls,
.log-column.minimized .log-stream {
display: none;
}
.log-column.minimized h2 {
writing-mode: vertical-rl;
text-orientation: mixed;
transform: rotate(180deg);
white-space: nowrap;
margin: auto;
}
</style>
</head>
<body>
<pre id="log-stream">Waiting for logs...</pre>
<div id="log-controls">
<input type="text" id="filter-input" placeholder="regex filter">
<button id="clear-button">clear</button>
<div class="log-container">
<div class="log-column">
<h2>Proxy Logs</h2>
<div class="log-controls">
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
<button id="proxy-clear-button">clear</button>
</div>
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
</div>
<div class="log-column minimized">
<h2>Upstream Logs</h2>
<div class="log-controls">
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
<button id="upstream-clear-button">clear</button>
</div>
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
</div>
</div>
<script>
const logStream = document.getElementById('log-stream');
const filterInput = document.getElementById('filter-input');
var logData = "";
let regexFilter = null;
class LogStream {
constructor(streamElement, filterInput, clearButton, endpoint) {
this.streamElement = streamElement;
this.filterInput = filterInput;
this.clearButton = clearButton;
this.endpoint = endpoint;
this.logData = "";
this.regexFilter = null;
this.eventSource = null;
function setupEventSource() {
if (typeof(EventSource) !== "undefined") {
const eventSource = new EventSource("/logs/streamSSE");
eventSource.onmessage = function(event) {
logData += event.data;
render()
};
eventSource.onerror = function(err) {
logData = "EventSource failed: " + err.message;
};
} else {
logData = "SSE Not supported by this browser."
this.initialize();
}
}
// poor-ai's react ¯\_(ツ)_/¯
function render() {
if (regexFilter) {
const lines = logData.split('\n');
const filteredLines = lines.filter(line => {
return regexFilter === null || regexFilter.test(line);
initialize() {
this.filterInput.addEventListener('input', () => this.updateFilter());
this.clearButton.addEventListener('click', () => {
this.filterInput.value = "";
this.regexFilter = null;
this.render();
});
if (filteredLines.length > 0) {
logStream.textContent = filteredLines.join('\n') + '\n';
} else {
logStream.textContent = "";
}
} else {
logStream.textContent = logData;
this.setupEventSource();
}
logStream.scrollTop = logStream.scrollHeight;
}
setupEventSource() {
if (typeof(EventSource) === "undefined") {
this.logData = "SSE Not supported by this browser.";
this.render();
return;
}
const connect = () => {
this.eventSource = new EventSource(this.endpoint);
this.eventSource.onmessage = (event) => {
this.logData += event.data;
this.render();
};
this.eventSource.onerror = (err) => {
// Close the current connection
this.eventSource.close();
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
this.render();
// Attempt to reconnect after 5 seconds
setTimeout(() => {
this.logData += "Attempting to reconnect...\n";
this.render();
connect();
}, 5000);
};
};
// Initial connection
connect();
}
render() {
let content = this.logData;
if (this.regexFilter) {
const lines = content.split('\n');
const filteredLines = lines.filter(line => this.regexFilter.test(line));
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
}
this.streamElement.textContent = content;
this.streamElement.scrollTop = this.streamElement.scrollHeight;
}
updateFilter() {
const pattern = this.filterInput.value.trim();
this.filterInput.classList.remove('regex-error');
if (!pattern) {
this.regexFilter = null;
this.render();
return;
}
function updateFilter() {
const pattern = filterInput.value.trim();
filterInput.classList.remove('regex-error');
if (pattern) {
try {
regexFilter = new RegExp(pattern);
this.regexFilter = new RegExp(pattern);
} catch (e) {
console.error("Invalid regex pattern:", e);
regexFilter = null;
filterInput.classList.add('regex-error');
return
this.regexFilter = null;
this.filterInput.classList.add('regex-error');
return;
}
} else {
regexFilter = null;
}
render();
this.render();
}
}
filterInput.addEventListener('input', updateFilter);
document.getElementById('clear-button').addEventListener('click', () => {
filterInput.value = "";
regexFilter = null;
render();
// Initialize both log streams
document.addEventListener('DOMContentLoaded', () => {
new LogStream(
document.getElementById('proxy-log-stream'),
document.getElementById('proxy-filter-input'),
document.getElementById('proxy-clear-button'),
"/logs/streamSSE/proxy"
);
new LogStream(
document.getElementById('upstream-log-stream'),
document.getElementById('upstream-filter-input'),
document.getElementById('upstream-clear-button'),
"/logs/streamSSE/upstream"
);
// Initialize clickable headers
document.querySelectorAll('h2').forEach(header => {
header.addEventListener('click', () => {
const column = header.closest('.log-column');
column.classList.toggle('minimized');
});
});
});
setupEventSource();
updateFilter();
</script>
</body>
</html>
+90
View File
@@ -2,11 +2,21 @@ package proxy
import (
"container/ring"
"fmt"
"io"
"os"
"sync"
)
type LogLevel int
const (
LevelDebug LogLevel = iota
LevelInfo
LevelWarn
LevelError
)
type LogMonitor struct {
clients map[chan []byte]bool
mu sync.RWMutex
@@ -15,6 +25,10 @@ type LogMonitor struct {
// typically this can be os.Stdout
stdout io.Writer
// logging levels
level LogLevel
prefix string
}
func NewLogMonitor() *LogMonitor {
@@ -26,6 +40,8 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
clients: make(map[chan []byte]bool),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
}
}
@@ -94,3 +110,77 @@ func (w *LogMonitor) broadcast(msg []byte) {
}
}
}
func (w *LogMonitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
+125 -108
View File
@@ -30,11 +30,15 @@ const (
)
type Process struct {
ID string
config ModelConfig
cmd *exec.Cmd
logMonitor *LogMonitor
healthCheckTimeout int
ID string
config ModelConfig
cmd *exec.Cmd
processLogger *LogMonitor
proxyLogger *LogMonitor
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time
@@ -51,54 +55,68 @@ type Process struct {
shutdownCancel context.CancelFunc
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background())
return &Process{
ID: ID,
config: config,
cmd: nil,
logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
ID: ID,
config: config,
cmd: nil,
processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout,
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
}
}
func (p *Process) setState(newState ProcessState) error {
// enforce valid state transitions
invalidTransition := false
if p.state == StateStopped {
// stopped -> starting
if newState != StateStarting {
invalidTransition = true
}
} else if p.state == StateStarting {
// starting -> ready | failed | stopping
if newState != StateReady && newState != StateFailed && newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateReady {
// ready -> stopping
if newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateStopping {
// stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
// LogMonitor returns the log monitor associated with the process.
func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
ErrInvalidStateTransition = errors.New("invalid state transition")
)
// swapState performs a compare and swap of the state atomically. It returns the current state
// and an error if the swap failed.
func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != expectedState {
return p.state, ErrExpectedStateMismatch
}
if invalidTransition {
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState)
return p.state, ErrInvalidStateTransition
}
p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState)
p.state = newState
return nil
return p.state, nil
}
// Helper function to encapsulate transition rules
func isValidTransition(from, to ProcessState) bool {
switch from {
case StateStopped:
return to == StateStarting
case StateStarting:
return to == StateReady || to == StateFailed || to == StateStopping
case StateReady:
return to == StateStopping
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateFailed, StateShutdown:
return false // No transitions allowed from these states
}
return false
}
func (p *Process) CurrentState() ProcessState {
@@ -116,47 +134,48 @@ func (p *Process) start() error {
return fmt.Errorf("can not start(), upstream proxy missing")
}
// wait for the other start() to complete
curState := p.CurrentState()
if curState == StateReady {
return nil
}
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state != StateReady {
return fmt.Errorf("start() failed current state: %v", state)
}
return nil
}
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if err := p.setState(StateStarting); err != nil {
return err
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand()
if err != nil {
return fmt.Errorf("unable to get sanitized command: %v", err)
}
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
// already starting, just wait for it to complete and expect
// it to be be in the Ready start after. If not, return an error
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state == StateReady {
return nil
} else {
return fmt.Errorf("process was already starting but wound up in state %v", state)
}
} else {
return fmt.Errorf("processes was in state %v when start() was called", curState)
}
} else {
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
p.cmd = exec.Command(args[0], args[1:]...)
p.cmd.Stdout = p.logMonitor
p.cmd.Stderr = p.logMonitor
p.cmd.Stdout = p.processLogger
p.cmd.Stderr = p.processLogger
p.cmd.Env = p.config.Env
err = p.cmd.Start()
// Set process state to failed
if err != nil {
p.setState(StateFailed)
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
return fmt.Errorf(
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
err, curState, swapErr,
)
}
return fmt.Errorf("start() failed: %v", err)
}
@@ -191,31 +210,35 @@ func (p *Process) start() error {
)
defer cancelHealthCheck()
// Health check loop
loop:
// Ready Check loop
for {
select {
case <-checkDeadline.Done():
p.setState(StateFailed)
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
} else {
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown")
default:
if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("Health check passed on %s", healthURL)
cancelHealthCheck()
break loop
} else {
if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime)
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
p.proxyLogger.Infof("Connection refused on %s, retrying in %.0fs", healthURL, ttl.Seconds())
} else {
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
p.proxyLogger.Infof("Health check error on %s, %v", healthURL, err)
}
}
}
<-time.After(time.Second)
<-time.After(p.healthCheckLoopInterval)
}
}
@@ -226,7 +249,7 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) {
if p.state != StateReady {
if p.CurrentState() != StateReady {
return
}
@@ -234,7 +257,8 @@ func (p *Process) start() error {
p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
p.Stop()
return
}
@@ -242,26 +266,28 @@ func (p *Process) start() error {
}()
}
return p.setState(StateReady)
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
} else {
return nil
}
}
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
// calling Stop() when state is invalid is a no-op
if err := p.setState(StateStopping); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("Stop() Ready -> StateStopping err: %v, current state: %v", err, curState)
return
}
// stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second)
if err := p.setState(StateStopped); err != nil {
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Infof("Stop() StateStopping -> StateStopped err: %v, current state: %v", err, curState)
}
}
@@ -269,19 +295,9 @@ func (p *Process) Stop() {
// of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second)
if err := p.setState(StateShutdown); err != nil {
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
}
p.setState(StateShutdown)
p.state = StateShutdown
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -296,31 +312,32 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
}()
if p.cmd == nil || p.cmd.Process == nil {
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
return
}
p.cmd.Process.Signal(syscall.SIGTERM)
if err := p.terminateProcess(); err != nil {
p.proxyLogger.Infof("Failed to gracefully terminate process [%s]: %v", p.ID, err)
}
select {
case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
p.cmd.Process.Kill()
case err := <-sigtermNormal:
if err != nil {
if errno, ok := err.(syscall.Errno); ok {
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
p.proxyLogger.Infof("Process [%s] stopped OK", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") {
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
p.proxyLogger.Infof("Process [%s] interrupted OK", p.ID)
} else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
p.proxyLogger.Warnf("Process [%s] ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
}
} else {
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
p.proxyLogger.Errorf("Process [%s] exited >> %v", p.ID, err)
}
}
}
+9
View File
@@ -0,0 +1,9 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}
+14
View File
@@ -0,0 +1,14 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}
+47 -39
View File
@@ -5,7 +5,6 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
@@ -13,13 +12,17 @@ import (
"github.com/stretchr/testify/assert"
)
var (
discardLogger = NewLogMonitorWriter(io.Discard)
)
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
// Create a process
process := NewProcess("test-process", 5, config, logMonitor)
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
@@ -52,11 +55,10 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// are all handled successfully, even though they all may ask for the process to .start()
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, logMonitor)
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
defer process.Stop()
var wg sync.WaitGroup
@@ -84,7 +86,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health",
}
process := NewProcess("broken", 1, config, NewLogMonitor())
process := NewProcess("broken", 1, config, discardLogger, discardLogger)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
@@ -109,7 +111,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger)
defer process.Stop()
// this should take 4 seconds
@@ -151,7 +153,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
process := NewProcess("ttl", 2, config, discardLogger, discardLogger)
defer process.Stop()
for i := 0; i < 100; i++ {
@@ -169,6 +171,8 @@ func TestProcess_LowTTLValue(t *testing.T) {
}
// issue #19
// This test makes sure using Process.Stop() does not affect pending HTTP
// requests. All HTTP requests in this test should complete successfully.
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
@@ -176,7 +180,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
process := NewProcess("t", 10, config, discardLogger, discardLogger)
defer process.Stop()
results := map[string]string{
@@ -192,8 +196,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
wg.Add(1)
go func(key string) {
defer wg.Done()
// send a request that should take 5 * 200ms (1 second) to complete
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
// send a request where simple-responder is will wait 300ms before responding
// this will simulate an in-progress request.
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
@@ -209,9 +214,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
}(key)
}
// stop the requests in the middle
// Stop the process while requests are still being processed
go func() {
<-time.After(500 * time.Millisecond)
<-time.After(150 * time.Millisecond)
process.Stop()
}()
@@ -222,39 +227,40 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
}
}
func TestSetState(t *testing.T) {
func TestProcess_SwapState(t *testing.T) {
tests := []struct {
name string
currentState ProcessState
expectedState ProcessState
newState ProcessState
expectedError error
expectedResult ProcessState
}{
{"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped},
{"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting},
{"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady},
{"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady},
{"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping},
{"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed},
{"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed},
{"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown},
{"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown},
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p := &Process{
state: test.currentState,
}
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger)
p.state = test.currentState
err := p.setState(test.newState)
resultState, err := p.swapState(test.expectedState, test.newState)
if err != nil && test.expectedError == nil {
t.Errorf("Unexpected error: %v", err)
} else if err == nil && test.expectedError != nil {
@@ -265,8 +271,8 @@ func TestSetState(t *testing.T) {
}
}
if p.state != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state)
if resultState != test.expectedResult {
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
}
})
}
@@ -277,7 +283,6 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
t.Skip("skipping long shutdown test")
}
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
// make a config where the healthcheck will always fail because port is wrong
@@ -285,13 +290,16 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
// start a goroutine to simulate a shutdown
var wg sync.WaitGroup
go func() {
defer wg.Done()
<-time.After(time.Second * 2)
<-time.After(time.Millisecond * 500)
process.Shutdown()
}()
wg.Add(1)
+257 -58
View File
@@ -5,7 +5,9 @@ import (
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"sort"
"strconv"
"strings"
@@ -13,6 +15,8 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -24,59 +28,96 @@ type ProxyManager struct {
config *Config
currentProcesses map[string]*Process
logMonitor *LogMonitor
ginEngine *gin.Engine
// logging
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
muxLogger *LogMonitor
}
func New(config *Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{
config: config,
currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(),
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
}
if config.LogRequests {
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
pm.ginEngine.Use(func(c *gin.Context) {
// Start timer
start := time.Now()
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// capture these because /upstream/:model rewrites them in c.Next()
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
// Process request
c.Next()
// Process request
c.Next()
// Stop timer
duration := time.Since(start)
// Stop timer
duration := time.Since(start)
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
statusCode := c.Writer.Status()
bodySize := c.Writer.Size()
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
clientIP,
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
}
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
clientIP,
method,
path,
c.Request.Proto,
statusCode,
bodySize,
c.Request.UserAgent(),
duration,
)
})
// see: https://github.com/mostlygeek/llama-swap/issues/42
// see: issue: #81, #77 and #42 for CORS issues
// respond with permissive OPTIONS for any endpoint
pm.ginEngine.Use(func(c *gin.Context) {
if c.Request.Method == "OPTIONS" {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.AbortWithStatus(204)
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// allow whatever the client requested by default
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
c.Header("Access-Control-Allow-Headers", sanitized)
} else {
c.Header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, Accept, X-Requested-With",
)
}
c.Header("Access-Control-Max-Age", "86400")
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
@@ -93,6 +134,7 @@ func New(config *Config) *ProxyManager {
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -100,10 +142,16 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/", func(c *gin.Context) {
// Set the Content-Type header to text/html
c.Header("Content-Type", "text/html")
@@ -222,11 +270,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
@@ -259,19 +303,20 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found {
pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel)
return process, nil
}
// stop all running models
pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel)
pm.stopProcesses()
if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
} else {
@@ -282,7 +327,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger)
processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process
}
@@ -342,29 +387,147 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return
}
var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
return
}
model, ok := requestBody["model"].(string)
if !ok {
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
if process, err := pm.swapModel(model); err != nil {
process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
}
// issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
} else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Swap to the requested model
process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
}
// Get profile name and model name from the requested model
profileName, modelName := splitRequestedModel(requestedModel)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
if process.config.UseModelName != "" {
fieldValue = process.config.UseModelName
} else if profileName != "" {
fieldValue = modelName
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open()
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
return
}
if _, err = io.Copy(formFile, file); err != nil {
file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return
}
file.Close()
}
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// Use the modified request for proxying
process.ProxyRequest(c.Writer, modifiedReq)
}
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
@@ -377,6 +540,42 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
}
}
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses()
c.String(http.StatusOK, "OK")
}
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, process := range pm.currentProcesses {
// Append the process ID and State (multiple entries if profiles are being used).
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
})
}
// Put the results under the `running` key.
response := gin.H{
"running": runningProcesses,
}
context.JSON(http.StatusOK, response) // Always return 200 OK
}
func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName
}
func splitRequestedModel(requestedModel string) (string, string) {
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
return profileName, modelName
}
+37 -8
View File
@@ -9,7 +9,6 @@ import (
)
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") {
// Set the Content-Type header to text/html
@@ -28,7 +27,7 @@ func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
}
} else {
c.Header("Content-Type", "text/plain")
history := pm.logMonitor.GetHistory()
history := pm.muxLogger.GetHistory()
_, err := c.Writer.Write(history)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
@@ -42,8 +41,14 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher)
@@ -56,7 +61,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// Send history first if not skipped
if !skipHistory {
history := pm.logMonitor.GetHistory()
history := logger.GetHistory()
if len(history) != 0 {
c.Writer.Write(history)
flusher.Flush()
@@ -85,15 +90,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
ch := pm.logMonitor.Subscribe()
defer pm.logMonitor.Unsubscribe(ch)
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done()
// Send history first if not skipped
_, skipHistory := c.GetQuery("no-history")
if !skipHistory {
history := pm.logMonitor.GetHistory()
history := logger.GetHistory()
if len(history) != 0 {
c.SSEvent("message", string(history))
c.Writer.Flush()
@@ -111,3 +122,21 @@ func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
}
}
}
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return logger, nil
}
+404
View File
@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/json"
"fmt"
"math/rand"
"mime/multipart"
"net/http"
"net/http/httptest"
"sync"
@@ -304,3 +306,405 @@ func TestProxyManager_Shutdown(t *testing.T) {
}()
wg.Wait()
}
func TestProxyManager_Unload(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
}
proxy := New(config)
proc, err := proxy.swapModel("model1")
assert.NoError(t, err)
assert.NotNil(t, proc)
assert.Len(t, proxy.currentProcesses, 1)
req := httptest.NewRequest("GET", "/unload", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// issue 62, strip profile slug from model name
func TestProxyManager_StripProfileSlug(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
}
// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
}
// Define a helper struct to parse the JSON response.
type RunningResponse struct {
Running []struct {
Model string `json:"model"`
State string `json:"state"`
} `json:"running"`
}
// Create proxy once for all tests
proxy := New(config)
defer proxy.StopProcesses()
t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response RunningResponse
// Check if this is a valid JSON object.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// We should have an empty running array here.
assert.Empty(t, response.Running, "expected no running models")
})
t.Run("single model loaded", func(t *testing.T) {
// Load just a model.
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// Check if we have a single array element.
assert.Len(t, response.Running, 1)
// Is this the right model?
assert.Equal(t, "model1", response.Running[0].Model)
// Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State)
})
t.Run("multiple models via profile", func(t *testing.T) {
// Load more than one model.
for _, model := range []string{"model1", "model2"} {
profileModel := ProcessKeyName("test", model)
reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Simulate the browser call.
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
var response RunningResponse
// The JSON response must be valid.
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// The response should contain 2 models.
assert.Len(t, response.Running, 2)
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
}
// Iterate through the models and check their states as well.
for _, entry := range response.Running {
_, exists := expectedModels[entry.Model]
assert.True(t, exists, "unexpected model %s", entry.Model)
assert.Equal(t, "ready", entry.State)
delete(expectedModels, entry.Model)
}
// Since we deleted each model while testing for its validity we should have no more models in the response.
assert.Empty(t, expectedModels, "unexpected additional models in response")
})
}
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"},
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
testCases := []struct {
name string
modelInput string
expectModel string
}{
{
name: "With Profile Prefix",
modelInput: "test:TheExpectedModel",
expectModel: "TheExpectedModel", // Profile prefix should be stripped
},
{
name: "Without Profile Prefix",
modelInput: "TheExpectedModel",
expectModel: "TheExpectedModel", // Should remain the same
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tc.modelInput))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
// Generate random content length between 10 and 20
contentLength := rand.Intn(11) + 10 // 10 to 20
content := make([]byte, contentLength)
_, err = fw.Write(content)
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tc.expectModel, response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
})
}
}
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
}
// Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{
"model1": modelConfig,
},
}
proxy := New(config)
defer proxy.StopProcesses()
tests := []struct {
description string
requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
for _, tt := range tests {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogRequests: true,
}
tests := []struct {
name string
method string
requestHeaders map[string]string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "OPTIONS with no headers",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
},
},
{
name: "OPTIONS with specific headers",
method: "OPTIONS",
requestHeaders: map[string]string{
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
},
expectedStatus: http.StatusNoContent,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
},
},
{
name: "Non-OPTIONS request",
method: "GET",
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses()
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
for k, v := range tt.requestHeaders {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
for header, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(header))
}
})
}
}
+43
View File
@@ -0,0 +1,43 @@
package proxy
import (
"strings"
)
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
+77
View File
@@ -0,0 +1,77 @@
package proxy
import "testing"
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "single valid value",
input: "content-type",
expected: "content-type",
},
{
name: "multiple valid values",
input: "content-type, authorization, x-requested-with",
expected: "content-type, authorization, x-requested-with",
},
{
name: "values with extra spaces",
input: " content-type , authorization ",
expected: "content-type, authorization",
},
{
name: "values with tabs",
input: "content-type,\tauthorization",
expected: "content-type, authorization",
},
{
name: "values with invalid characters",
input: "content-type, auth\n, x-requested-with\r",
expected: "content-type, auth, x-requested-with",
},
{
name: "empty values in list",
input: "content-type,,authorization",
expected: "content-type, authorization",
},
{
name: "leading and trailing commas",
input: ",content-type,authorization,",
expected: "content-type, authorization",
},
{
name: "mixed valid and invalid values",
input: "content-type, \x00invalid, x-requested-with",
expected: "content-type, x-requested-with",
},
{
name: "mixed case values",
input: "Content-Type, my-Valid-Header, Another-hEader",
expected: "Content-Type, my-Valid-Header, Another-hEader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeAccessControlRequestHeaderValues(tt.input)
if got != tt.expected {
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
tt.input, got, tt.expected)
}
})
}
}