Compare commits

...

30 Commits

Author SHA1 Message Date
Benson Wong dea98733c3 proxy: extract metrics for v1/messages (#419) 2025-11-29 23:51:20 -08:00
Benson Wong bccce5fa19 go.mod,ui/package-lock.json: dependency and security updates (#418) 2025-11-29 22:27:22 -08:00
Benson Wong c968da1b73 proxy: add support for anthropic v1/messages api (#417)
* proxy: add support for anthropic v1/messages api
* proxy: restrict loading message to /v1/chat/completions
2025-11-29 22:09:07 -08:00
Ryan Steed a883d68d4f feat: Add support for custom llama.cpp base image and forked llama-swap repositories (#396)
* feat: Add support for custom llama.cpp base image and forked llama-swap repositories

- Introduce BASE_LLAMACPP_IMAGE env var to customize llama.cpp base image
- Introduce LS_REPO env var to customize llama-swap source
- Use GITHUB_REPOSITORY env var to automatically detect forked repos
- Update container tagging to use dynamic repo paths
- Pass build args for BASE_IMAGE and LS_REPO to Containerfile
- Enable flexible release downloads from forked repositories

* chore: quote entire curl options, appease coderabbitai
2025-11-29 20:59:15 -08:00
Ryan Steed b1dec8b735 docker: build both root and non-root container images (#412)
Change the user back to root for containers. Additionally, built a "non-root" labeled container for users who wish to have the additional security of running llama-swap as a lower privileged user.
2025-11-25 10:44:13 -08:00
Nikesh Parajuli 06523d8c1e feat: add platform-specific process attributes support (#411)
Fixes issues on Windows showing new windows for every process llama-swap spawns.
2025-11-24 21:39:56 -08:00
Ryan Steed 86e9b93c37 proxy,ui: add version endpoint and display version info in UI (#395)
- Add /api/version endpoint to ProxyManager that returns build date, commit hash, and version
- Implement SetVersion method to configure version info in ProxyManager
- Add version info fetching to APIProvider and display in ConnectionStatus component
- Include version info in UI context and update dependencies
- Add tests for version endpoint functionality
2025-11-17 10:43:47 -08:00
Ryan Steed 3acace810f proxy: add configurable logging timestamp format (#401)
introduces a new configuration option logTimeFormat that allows customizing the timestamp in log messages using golang's built in time format constants. The default remains no timestamp.
2025-11-16 10:21:59 -08:00
Ryan Steed 554d29e87d feat: enhance model listing to include aliases (#400)
introduce includeAliasesInList as a new configuration setting (default false) that includes aliases in v1/models

Fixes #399
2025-11-15 14:35:26 -08:00
Benson Wong 3567b7df08 Update image in README.md for web UI section 2025-11-08 15:29:37 -08:00
Benson Wong 38738525c9 config.example.yaml: add modeline for schema validation 2025-11-08 15:08:55 -08:00
Benson Wong c0fc858193 Add configuration file JSON schema (#393)
* add json schema for configuration
* add GH action to validate schema
2025-11-08 15:04:14 -08:00
Benson Wong b429349e8a add /ui/ to wol-proxy polling (#388) 2025-11-08 14:16:12 -08:00
Ryan Steed eab2efd7b5 feat: improve llama.cpp base image tag for cpu (#391)
Refactor the container build script to resolve llama.cpp base image for CPU, also tag these builds accordingly.

- For CPU containers, now fetch the latest 'server' tagged llama.cpp image instead of using a generic 'server' tag
- Cleans up the docker build command to use dynamic BASE_TAG variable
- Maintains existing push functionality for built images
2025-11-08 09:56:49 -08:00
Benson Wong 6aedbe121a cmd/wol-proxy: show a loading page for / (#381)
When requesting / wol-proxy will show a loading page that polls /status
every second. When the upstream server is ready the loading page will
refresh causing the actual root page to be displayed
2025-11-03 19:37:06 -08:00
Ryan Steed b24467ab89 fix: update containerfile user/group management commands (#379)
- Replace `addgroup` with `groupadd` for system group creation
- Replace `adduser` with `useradd` for system user creation
- Maintain same functionality while using more standard POSIX commands
2025-11-03 17:17:40 -05:00
Benson Wong 12b69fb718 proxy: recover from panic in Process.statusUpdate (#378)
Process.statusUpdate() panics when it can not write data, usually from a
client disconnect. Since it runs in a goroutine and did not have a
recover() the result was a crash.

ref: https://github.com/mostlygeek/llama-swap/discussions/326#discussioncomment-14856197
2025-11-03 05:30:09 -08:00
Ryan Steed f91a8b2462 refactor: update Containerfile to support non-root user execution and improve security (#368)
Set default container user/group to lower privilege app user 

* refactor: update Containerfile to support non-root user execution and improve security

- Updated LS_VER argument from 89 to 170 to use the latest version
- Added UID/GID arguments with default values of 0 (root) for backward compatibility
- Added USER_HOME environment variable set to /root
- Implemented conditional user/group creation logic that only runs when UID/GID are not 0
- Created necessary directory structure with proper ownership using mkdir and chown commands
- Switched to non-root user execution for improved security posture
- Updated COPY instruction to use --chown flag for proper file ownership

* chore: update containerfile to use non-root user with proper UID/GID

- Changed default UID and GID from 0 (root) to 10001 for security best practices
- Updated USER_HOME from /root to /app to avoid running as root user
2025-10-31 17:01:04 -07:00
Benson Wong a89b803d4a Stream loading state when swapping models (#371)
Swapping models can take a long time and leave a lot of silence while the model is loading. Rather than silently load the model in the background, this PR allows llama-swap to send status updates in the reasoning_content of a streaming chat response.

Fixes: #366
2025-10-29 00:09:39 -07:00
Benson Wong f852689104 proxy: add panic recovery to Process.ProxyRequest (#363)
Switching to use httputil.ReverseProxy in #342 introduced a possible
panic if a client disconnects while streaming the body. Since llama-swap
does not use http.Server the recover() is not automatically there.

- introduce a recover() in Process.ProxyRequest to recover and log the
  event
- add TestProcess_ReverseProxyPanicIsHandled to reproduce and test the
  fix

fixes: #362
2025-10-25 20:40:05 -07:00
Benson Wong e250e71e59 Include metrics from upstream chat requests (#361)
* proxy: refactor metrics recording

- remove metrics_middleware.go as this wrapper is no longer needed. This
  also eliminiates double body parsing for the modelID
- move metrics parsing to be part of MetricsMonitor
- refactor how metrics are recording in ProxyManager
- add MetricsMonitor tests
- improve mem efficiency of processStreamingResponse
- add benchmarks for MetricsMonitor.addMetrics
- proxy: refactor MetricsMonitor to be more safe handling errors
2025-10-25 17:38:18 -07:00
Benson Wong d18dc26d01 cmd/wol-proxy: tweak logs to show what is causing wake ups (#356)
fix the extra wake ups being caused by wol-proxy

* cmd/wol-proxy: tweak logs to show what is causing wake ups
* cmd/wol-proxy: add skip wakeup
* cmd/wol-proxy: replace ticker with SSE connection
* cmd/wol-proxy: increase scanner buffer size
* cmd/wol-proxy: improve failure tracking
2025-10-25 11:04:31 -07:00
Benson Wong 8357714421 ui: fix avg token/sec calculation on models page (#357)
* ui: use percentiles for token stats
* ui: add histogram of metrics
* update vite to remove security warnings

fixes #355
2025-10-23 22:22:24 -07:00
Benson Wong c07179d6e2 cmd/wol-proxy: add wol-proxy (#352)
add a wake-on-lan proxy for llama-swap. When the target llama-swap server is unreachable it will send hold a request, send a WoL packet and proxy the request when llama-swap is available.
2025-10-20 20:55:02 -07:00
Benson Wong 7ff50631e0 Update README for setup instructions clarity [skip ci] 2025-10-19 14:55:23 -07:00
Benson Wong 9fc0431531 Clean up and Documentation (#347) [skip ci]
* cmd,misc: move misc binaries to cmd/
* docs: add docs and move examples/ there
* misc: remove unused misc/assets dir
* docs: add configuration.md
* update README with better structure

Updates: #334
2025-10-19 14:53:13 -07:00
David Wen Riccardi-Zhu 6516532568 Add optional TLS support (#340)
* Add optional TLS support

Introduce HTTPS support with net/http Server.ListenAndServeTLS.

This should enable the option of serving via HTTPS without a reverse
proxy.

Add two flags:
- tls-cert-file (path to the TLS certificate file)
- tls-key-file (path to the TLS private key file)

Both flags must be supplied together; otherwise exit with error.

If both flags are present, call srv.ListenAndServeTLS.
If not, fall back to the existing srv.ListenAndServe (HTTP); no changes
to existing non‑TLS behavior.
2025-10-15 19:29:02 -07:00
David Wen Riccardi-Zhu d58a8b85bf Refactor to use httputil.ReverseProxy (#342)
* Refactor to use httputil.ReverseProxy

Refactor manual HTTP proxying logic in Process.ProxyRequest to use the standard
library's httputil.ReverseProxy.

* Refactor TestProcess_ForceStopWithKill test

Update to handle behavior with httputil.ReverseProxy.

* Fix gin interface conversion panic
2025-10-13 16:47:04 -07:00
Benson Wong caf9e98b1e Fix race conditions in proxy.Process (#349)
- Fix data races found in proxy.Process by go's race detector. 
- Add data race detection to the CI tests. 

Fixes #348
2025-10-13 16:42:49 -07:00
Benson Wong 539278343b ui: tweak vertical space for mobile (#343) 2025-10-10 10:05:36 -07:00
52 changed files with 3764 additions and 785 deletions
+41
View File
@@ -0,0 +1,41 @@
name: Validate JSON Schema
on:
pull_request:
paths:
- "config-schema.json"
push:
branches:
- main
paths:
- "config-schema.json"
workflow_dispatch:
jobs:
validate-schema:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Validate JSON Schema
run: |
# Check if the file is valid JSON
if ! jq empty config-schema.json 2>/dev/null; then
echo "Error: config-schema.json is not valid JSON"
exit 1
fi
# Validate that it's a valid JSON Schema
# Check for required $schema field
if ! jq -e '."$schema"' config-schema.json > /dev/null; then
echo "Warning: config-schema.json should have a \$schema field"
fi
# Check that it has either properties or definitions
if ! jq -e '.properties or .definitions or ."$defs"' config-schema.json > /dev/null; then
echo "Warning: JSON Schema should contain properties, definitions, or \$defs"
fi
echo "✓ config-schema.json is valid"
+1 -1
View File
@@ -11,7 +11,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
## Testing
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors.
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
## Workflow Tasks
+11 -5
View File
@@ -33,7 +33,7 @@ test: proxy/ui_dist/placeholder.txt
# for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt
go test -count=1 ./proxy/...
go test -race -count=1 ./proxy/...
ui/node_modules:
cd ui && npm install
@@ -61,12 +61,12 @@ windows: ui
# for testing proxy.Process
simple-responder:
@echo "Building simple responder"
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 cmd/simple-responder/simple-responder.go
simple-responder-windows:
@echo "Building simple responder for windows"
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe cmd/simple-responder/simple-responder.go
# Ensure build directory exists
$(BUILD_DIR):
@@ -86,5 +86,11 @@ release:
echo "tagging new version: $$new_tag"; \
git tag "$$new_tag";
GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
wol-proxy: $(BUILD_DIR)
@echo "Building wol-proxy"
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
# Phony targets
.PHONY: all clean ui mac linux windows simple-responder test test-all test-dev
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
+140 -152
View File
@@ -5,74 +5,166 @@
# llama-swap
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary, a provided docker images or Homebrew.
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
## Features:
- ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
- ✅ On-demand model switching
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- future proof, upgrade your inference servers at any time.
- ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/embeddings`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- ✅ llama-server (llama.cpp) supported endpoints:
- ✅ llama-server (llama.cpp) supported endpoints
- `v1/rerank`, `v1/reranking`, `/rerank`
- `/infill` - for code infilling
- `/completion` - for completion endpoint
- ✅ llama-swap custom API endpoints
- ✅ llama-swap API
- `/ui` - web UI
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/log` - remote log monitoring
- `/health` - just returns "OK"
-Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- ✅ Full control over server settings per model
- ✅ Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
-Customizable
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- Automatic unloading of models after timeout by setting a `ttl`
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
### Web UI
llama-swap includes a real time web interface for monitoring logs and controlling models:
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
The Activity Page shows recent requests:
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
## Installation
llama-swap can be installed in multiple ways
1. Docker
2. Homebrew (OSX and Linux)
3. WinGet
4. From release binaries
5. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc).
```shell
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
# run with a custom configuration and models directory
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
```
<details>
<summary>
more examples
</summary>
```shell
# pull latest images per platform
docker pull ghcr.io/mostlygeek/llama-swap:cpu
docker pull ghcr.io/mostlygeek/llama-swap:cuda
docker pull ghcr.io/mostlygeek/llama-swap:vulkan
docker pull ghcr.io/mostlygeek/llama-swap:intel
docker pull ghcr.io/mostlygeek/llama-swap:musa
# tagged llama-swap, platform and llama-server version images
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
```
</details>
### Homebrew Install (macOS/Linux)
```shell
brew tap mostlygeek/llama-swap
brew install llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
### WinGet Install (Windows)
> [!NOTE]
> WinGet is maintained by community contributor [Dvd-Znf](https://github.com/Dvd-Znf) ([#327](https://github.com/mostlygeek/llama-swap/issues/327)). It is not an official part of llama-swap.
```shell
# install
C:\> winget install llama-swap
# upgrade
C:\> winget upgrade llama-swap
```
### Pre-built Binaries
Binaries are available on the [release](https://github.com/mostlygeek/llama-swap/releases) page for Linux, Mac, Windows and FreeBSD.
### Building from source
1. Building requires Go and Node.js (for UI).
1. `git clone https://github.com/mostlygeek/llama-swap.git`
1. `make clean all`
1. look in the `build/` subdirectory for the llama-swap binary
## Configuration
```yaml
# minimum viable config.yaml
models:
model1:
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
```
That's all you need to get started:
1. `models` - holds all model configurations
2. `model1` - the ID used in API calls
3. `cmd` - the command to run to start the server.
4. `${PORT}` - an automatically assigned port number
Almost all configuration settings are optional and can be added one step at a time:
- Advanced features
- `groups` to run multiple models at once
- `hooks` to run things on startup
- `macros` reusable snippets
- Model customization
- `ttl` to automatically unload models
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
- `env` to pass custom environment variables to inference servers
- `cmdStop` gracefully stop Docker/Podman containers
- `useModelName` to override model names sent to upstream servers
- `${PORT}` automatic port variables for dynamic port assignment
- `filters` rewrite parts of requests before sending to the upstream server
See the [configuration documentation](docs/configuration.md) for all options.
## How does llama-swap work?
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## config.yaml
llama-swap is managed entirely through a yaml configuration file.
It can be very minimal to start:
```yaml
models:
"qwen2.5":
cmd: |
/path/to/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port ${PORT}
```
However, there are many more capabilities that llama-swap supports:
- `groups` to run multiple models at once
- `ttl` to automatically unload models
- `macros` for reusable snippets
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
- `env` to pass custom environment variables to inference servers
- `cmdStop` for to gracefully stop Docker/Podman containers
- `useModelName` to override model names sent to upstream servers
- `healthCheckTimeout` to control model startup wait times
- `${PORT}` automatic port variables for dynamic port assignment
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
## Reverse Proxy Configuration (nginx)
If you deploy llama-swap behind nginx, disable response buffering for streaming endpoints. By default, nginx buffers responses which breaks ServerSent Events (SSE) and streaming chat completion. ([#236](https://github.com/mostlygeek/llama-swap/issues/236))
@@ -97,111 +189,7 @@ location /v1/chat/completions {
As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. However, explicitly disabling `proxy_buffering` at your reverse proxy is still recommended for reliable streaming behavior.
## Web UI
llama-swap includes a real time web interface for monitoring logs and models:
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/adef4a8e-de0b-49db-885a-8f6dedae6799" />
The Activity Page shows recent requests:
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
## Installation
llama-swap can be installed in multiple ways
1. Docker
2. Homebrew (OSX and Linux)
3. From release binaries
4. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker images with llama-swap and llama-server are built nightly.
```shell
# use CPU inference comes with the example config above
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
# qwen2.5 0.5B
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
# SmolLM2 135M
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
```
<details>
<summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
They include:
- `ghcr.io/mostlygeek/llama-swap:cpu`
- `ghcr.io/mostlygeek/llama-swap:cuda`
- `ghcr.io/mostlygeek/llama-swap:intel`
- `ghcr.io/mostlygeek/llama-swap:vulkan`
- ROCm disabled until fixed in llama.cpp container
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
```shell
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
```
</details>
### Homebrew Install (macOS/Linux)
The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
```shell
# Set up tap and install formula
brew tap mostlygeek/llama-swap
brew install llama-swap
# Run llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
Available flags:
- `--config`: Path to the configuration file (default: `config.yaml`).
- `--listen`: Address and port to listen on (default: `:8080`).
- `--version`: Show version information and exit.
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
### Building from source
1. Build requires golang and nodejs for the user interface.
1. `git clone https://github.com/mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
## Monitoring Logs
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
CLI access is also supported:
## Monitoring Logs on the CLI
```shell
# sends up to the last 10KB of logs
@@ -227,11 +215,11 @@ curl -Ns 'http://host/logs/stream?no-history'
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals for proper shutdown.
## Star History
> [!NOTE]
> ⭐️ Star this project to help others discover it!
> ⭐️ Star this project to help others discover it!
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
+27
View File
@@ -0,0 +1,27 @@
# wol-proxy
wol-proxy automatically wakes up a suspended llama-swap server using Wake-on-LAN when requests are received.
When a request arrives and llama-swap is unavailable, wol-proxy sends a WOL packet and holds the request until the server becomes available. If the server doesn't respond within the timeout period (default: 60 seconds), the request is dropped.
This utility helps conserve energy by allowing GPU-heavy servers to remain suspended when idle, as they can consume hundreds of watts even when not actively processing requests.
## Usage
```shell
# minimal
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080
# everything
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080 \
# use debug log level
-log debug \
# altenerative listening port
-listen localhost:9999 \
# seconds to hold requests waiting for upstream to be ready
-timeout 30
```
## API
`GET /status` - that's it. Everything else is proxied to the upstream server.
+64
View File
@@ -0,0 +1,64 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Loading...</title>
<style>
body {
font-family: sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background: #f5f5f5;
}
.loader {
text-align: center;
}
.stats {
font-size: 18px;
color: #333;
margin: 20px 0;
}
.stats-label {
color: #666;
font-size: 14px;
}
</style>
</head>
<body>
<div class="loader">
<p>Waking up upstream server...</p>
<div class="stats">
<div><span class="stats-label">Time elapsed:</span> <span id="elapsed">0s</span></div>
<div><span id="attempts">&nbsp;</span></div>
</div>
</div>
<script>
var startTime = Date.now();
var attempts = 0;
setInterval(function() {
var elapsed = (Date.now() - startTime) / 1000;
document.getElementById('elapsed').textContent = elapsed.toFixed(1) + 's';
}, 100);
// Check status every second
setInterval(function() {
attempts++;
var dots = '.'.repeat((attempts % 10) || 10);
document.getElementById('attempts').textContent = dots;
fetch('/status')
.then(function(r) { return r.text(); })
.then(function(t) {
if (t.indexOf('status: ready') !== -1) {
location.reload();
}
})
.catch(function() {});
}, 1000);
</script>
</body>
</html>
+333
View File
@@ -0,0 +1,333 @@
package main
import (
"bufio"
"context"
_ "embed"
"errors"
"flag"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"time"
)
//go:embed index.html
var loadingPageHTML string
var (
flagMac = flag.String("mac", "", "mac address to send WoL packet to")
flagUpstream = flag.String("upstream", "", "upstream proxy address to send requests to")
flagListen = flag.String("listen", ":8080", "listen address to listen on")
flagLog = flag.String("log", "info", "log level (debug, info, warn, error)")
flagTimeout = flag.Int("timeout", 60, "seconds requests wait for upstream response before failing")
)
func main() {
flag.Parse()
switch *flagLog {
case "debug":
slog.SetLogLoggerLevel(slog.LevelDebug)
case "info":
slog.SetLogLoggerLevel(slog.LevelInfo)
case "warn":
slog.SetLogLoggerLevel(slog.LevelWarn)
case "error":
slog.SetLogLoggerLevel(slog.LevelError)
default:
slog.Error("invalid log level", "logLevel", *flagLog)
return
}
// Validate flags
if *flagListen == "" {
slog.Error("listen address is required")
return
}
if *flagMac == "" {
slog.Error("mac address is required")
return
}
if *flagTimeout < 1 {
slog.Error("timeout must be greater than 0")
return
}
var upstreamURL *url.URL
var err error
// validate mac address
if _, err = net.ParseMAC(*flagMac); err != nil {
slog.Error("invalid mac address", "error", err)
return
}
if *flagUpstream == "" {
slog.Error("upstream proxy address is required")
return
} else {
upstreamURL, err = url.ParseRequestURI(*flagUpstream)
if err != nil {
slog.Error("error parsing upstream url", "error", err)
return
}
}
proxy := newProxy(upstreamURL)
server := &http.Server{
Addr: *flagListen,
Handler: proxy,
}
// start the server
go func() {
slog.Info("server starting on", "address", *flagListen)
if err := server.ListenAndServe(); err != nil {
slog.Error("error starting server", "error", err)
}
}()
// graceful shutdown
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
<-ctx.Done()
server.Close()
}
type upstreamStatus string
const (
notready upstreamStatus = "not ready"
ready upstreamStatus = "ready"
)
type proxyServer struct {
upstreamProxy *httputil.ReverseProxy
failCount int
statusMutex sync.RWMutex
status upstreamStatus
}
func newProxy(url *url.URL) *proxyServer {
p := httputil.NewSingleHostReverseProxy(url)
proxy := &proxyServer{
upstreamProxy: p,
status: notready,
failCount: 0,
}
// start a goroutine to monitor upstream status via SSE
go func() {
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
client := &http.Client{
Timeout: 0, // No timeout for SSE connection
}
waitDuration := 10 * time.Second
for {
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
req, err := http.NewRequest("GET", eventsUrl, nil)
if err != nil {
slog.Warn("failed to create SSE request", "error", err)
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(waitDuration)
continue
}
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
resp, err := client.Do(req)
if err != nil {
slog.Error("failed to connect to SSE endpoint", "error", err)
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(10 * time.Second)
continue
}
if resp.StatusCode != http.StatusOK {
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(10 * time.Second)
continue
}
// Successfully connected to SSE endpoint
slog.Info("connected to SSE endpoint, upstream ready")
proxy.setStatus(ready)
proxy.resetFailures()
// Read from the SSE stream to detect disconnection
scanner := bufio.NewScanner(resp.Body)
// use a fairly large buffer to avoid scanner errors when reading large SSE events
buf := make([]byte, 0, 1024*1024*2)
scanner.Buffer(buf, 1024*1024*2)
events := 0
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
fmt.Print("Events: ")
}
for scanner.Scan() {
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
// Just read the events to keep connection alive
// We don't need to process the event data
events++
fmt.Printf("%d, ", events)
}
}
fmt.Println()
if err := scanner.Err(); err != nil {
slog.Error("error reading from SSE stream", "error", err)
}
// Connection closed or error occurred
_ = resp.Body.Close()
slog.Info("SSE connection closed, upstream not ready")
proxy.setStatus(notready)
proxy.incFail(1)
// Wait before reconnecting
time.Sleep(waitDuration)
}
}()
return proxy
}
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && r.URL.Path == "/status" {
status := string(p.getStatus())
failCount := p.getFailures()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(200)
fmt.Fprintf(w, "status: %s\n", status)
fmt.Fprintf(w, "failures: %d\n", failCount)
return
}
if p.getStatus() == notready {
path := r.URL.Path
if strings.HasPrefix(path, "/api/events") {
slog.Debug("Skipping wake up", "req", path)
w.WriteHeader(http.StatusNoContent)
return
}
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
if err := sendMagicPacket(*flagMac); err != nil {
slog.Warn("failed to send magic WoL packet", "error", err)
}
// For root or UI path requests, return loading page with status polling
// the web page will do the polling and redirect when ready
if path == "/" || strings.HasPrefix(path, "/ui/") {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, loadingPageHTML)
return
}
ticker := time.NewTicker(250 * time.Millisecond)
timeout, cancel := context.WithTimeout(context.Background(), time.Duration(*flagTimeout)*time.Second)
defer cancel()
loop:
for {
select {
case <-timeout.Done():
slog.Info("timeout waiting for upstream to be ready")
http.Error(w, "timeout", http.StatusRequestTimeout)
return
case <-ticker.C:
if p.getStatus() == ready {
ticker.Stop()
break loop
}
}
}
}
p.upstreamProxy.ServeHTTP(w, r)
}
func (p *proxyServer) getStatus() upstreamStatus {
p.statusMutex.RLock()
defer p.statusMutex.RUnlock()
return p.status
}
func (p *proxyServer) setStatus(status upstreamStatus) {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.status = status
}
func (p *proxyServer) incFail(num int) {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.failCount += num
}
func (p *proxyServer) getFailures() int {
p.statusMutex.RLock()
defer p.statusMutex.RUnlock()
return p.failCount
}
func (p *proxyServer) resetFailures() {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.failCount = 0
}
func sendMagicPacket(macAddr string) error {
hwAddr, err := net.ParseMAC(macAddr)
if err != nil {
return err
}
if len(hwAddr) != 6 {
return errors.New("invalid MAC address")
}
// Create the magic packet.
packet := make([]byte, 102)
// Add 6 bytes of 0xFF.
for i := 0; i < 6; i++ {
packet[i] = 0xFF
}
// Repeat the MAC address 16 times.
for i := 1; i <= 16; i++ {
copy(packet[i*6:], hwAddr)
}
// Send the packet using UDP.
addr := net.UDPAddr{
IP: net.IPv4bcast,
Port: 9,
}
conn, err := net.DialUDP("udp", nil, &addr)
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Write(packet)
return err
}
+278
View File
@@ -0,0 +1,278 @@
{
"$schema": "https://json-schema.org/draft-07/schema#",
"$id": "llama-swap-config-schema.json",
"title": "llama-swap configuration",
"description": "Configuration file for llama-swap",
"type": "object",
"required": [
"models"
],
"definitions": {
"macros": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string",
"minLength": 0,
"maxLength": 1024
},
{
"type": "number"
},
{
"type": "boolean"
}
]
},
"propertyNames": {
"type": "string",
"minLength": 1,
"maxLength": 64,
"pattern": "^[a-zA-Z0-9_-]+$",
"not": {
"enum": [
"PORT",
"MODEL_ID"
]
}
},
"default": {},
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
}
},
"properties": {
"healthCheckTimeout": {
"type": "integer",
"minimum": 15,
"default": 120,
"description": "Number of seconds to wait for a model to be ready to serve requests."
},
"logLevel": {
"type": "string",
"enum": [
"debug",
"info",
"warn",
"error"
],
"default": "info",
"description": "Sets the logging value. Valid values: debug, info, warn, error."
},
"logTimeFormat": {
"type": "string",
"enum": [
"",
"ansic",
"unixdate",
"rubydate",
"rfc822",
"rfc822z",
"rfc850",
"rfc1123",
"rfc1123z",
"rfc3339",
"rfc3339nano",
"kitchen",
"stamp",
"stampmilli",
"stampmicro",
"stampnano"
],
"default": "",
"description": "Enables and sets the logging timestamp format. Valid values: \"\", \"ansic\", \"unixdate\", \"rubydate\", \"rfc822\", \"rfc822z\", \"rfc850\", \"rfc1123\", \"rfc1123z\", \"rfc3339\", \"rfc3339nano\", \"kitchen\", \"stamp\", \"stampmilli\", \"stampmicro\", and \"stampnano\". For more info, read: https://pkg.go.dev/time#pkg-constants"
},
"metricsMaxInMemory": {
"type": "integer",
"default": 1000,
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
},
"startPort": {
"type": "integer",
"default": 5800,
"description": "Starting port number for the automatic ${PORT} macro. The ${PORT} macro is incremented for every model that uses it."
},
"sendLoadingState": {
"type": "boolean",
"default": false,
"description": "Inject loading status updates into the reasoning field. When true, a stream of loading messages will be sent to the client."
},
"includeAliasesInList": {
"type": "boolean",
"default": false,
"description": "Present aliases within the /v1/models OpenAI API listing. when true, model aliases will be output to the API model listing duplicating all fields except for Id so chat UIs can use the alias equivalent to the original."
},
"macros": {
"$ref": "#/definitions/macros"
},
"models": {
"type": "object",
"description": "A dictionary of model configurations. Each key is a model's ID. Model settings have defaults if not defined. The model's ID is available as ${MODEL_ID}.",
"additionalProperties": {
"type": "object",
"required": [
"cmd"
],
"properties": {
"macros": {
"$ref": "#/definitions/macros"
},
"cmd": {
"type": "string",
"minLength": 1,
"description": "Command to run to start the inference server. Macros can be used. Comments allowed with |."
},
"cmdStop": {
"type": "string",
"default": "",
"description": "Command to run to stop the model gracefully. Uses ${PID} macro for upstream process id. If empty, default shutdown behavior is used."
},
"name": {
"type": "string",
"default": "",
"maxLength": 128,
"description": "Display name for the model. Used in v1/models API response."
},
"description": {
"type": "string",
"default": "",
"maxLength": 1024,
"description": "Description for the model. Used in v1/models API response."
},
"env": {
"type": "array",
"items": {
"type": "string",
"pattern": "^[A-Z_][A-Z0-9_]*=.*$"
},
"default": [],
"description": "Array of environment variables to inject into cmd's environment. Each value is a string in ENV_NAME=value format."
},
"proxy": {
"type": "string",
"default": "http://localhost:${PORT}",
"format": "uri",
"description": "URL where llama-swap routes API requests. If custom port is used in cmd, this must be set."
},
"aliases": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"default": [],
"description": "Alternative model names for this configuration. Must be unique globally."
},
"checkEndpoint": {
"type": "string",
"default": "/health",
"pattern": "^/.*$|^none$",
"description": "URL path to check if the server is ready. Use 'none' to skip health checking."
},
"ttl": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
},
"useModelName": {
"type": "string",
"default": "",
"description": "Override the model name sent to upstream server. Useful if upstream expects a different name."
},
"filters": {
"type": "object",
"properties": {
"stripParams": {
"type": "string",
"default": "",
"pattern": "^[a-zA-Z0-9_, ]*$",
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
}
},
"additionalProperties": false,
"default": {},
"description": "Dictionary of filter settings. Only stripParams is supported."
},
"metadata": {
"type": "object",
"additionalProperties": true,
"default": {},
"description": "Dictionary of arbitrary values included in /v1/models. Can contain complex types. Only passed through in /v1/models responses."
},
"concurrencyLimit": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Overrides allowed number of active parallel requests to a model. 0 uses internal default of 10. >0 overrides default. Requests exceeding limit get HTTP 429."
},
"sendLoadingState": {
"type": "boolean",
"description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting."
},
"unlisted": {
"type": "boolean",
"default": false,
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
}
}
}
},
"groups": {
"type": "object",
"additionalProperties": {
"type": "object",
"required": [
"members"
],
"properties": {
"swap": {
"type": "boolean",
"default": true,
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
},
"exclusive": {
"type": "boolean",
"default": true,
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
},
"persistent": {
"type": "boolean",
"default": false,
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
},
"members": {
"type": "array",
"items": {
"type": "string"
},
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
}
}
},
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
},
"hooks": {
"type": "object",
"properties": {
"on_startup": {
"type": "object",
"properties": {
"preload": {
"type": "array",
"items": {
"type": "string"
},
"default": [],
"description": "List of model IDs to load on startup. Model names must match keys in models. When preloading multiple models, define a group to prevent swapping."
}
},
"additionalProperties": false,
"description": "Actions to perform on startup. Only supported action is preload."
}
},
"additionalProperties": false,
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
}
}
}
+35 -7
View File
@@ -1,3 +1,6 @@
# add this modeline for validation in vscode
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
#
# llama-swap YAML configuration example
# -------------------------------------
#
@@ -23,6 +26,14 @@ healthCheckTimeout: 500
# - Valid log levels: debug, info, warn, error
logLevel: info
# logTimeFormat: enables and sets the logging timestamp format
# - optional, default (disabled): ""
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
# "stamp", "stampmilli", "stampmicro", and "stampnano".
# - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: ""
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
@@ -35,6 +46,20 @@ metricsMaxInMemory: 1000
# - it is automatically incremented for every model that uses it
startPort: 10001
# sendLoadingState: inject loading status updates into the reasoning (thinking)
# field
# - optional, default: false
# - when true, a stream of loading messages will be sent to the client in the
# reasoning field so chat UIs can show that loading is in progress.
# - see #366 for more details
sendLoadingState: true
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
# - optional, default: false
# - when true, model aliases will be output to the API model listing duplicating
# all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false
# macros: a dictionary of string substitutions
# - optional, default: empty dictionary
# - macros are reusable snippets
@@ -64,7 +89,6 @@ macros:
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
# - below are examples of the all the settings a model can have
models:
# keys are the model names used in API requests
"llama":
# macros: a dictionary of string substitutions specific to this model
@@ -184,6 +208,10 @@ models:
# - recommended to be omitted and the default used
concurrencyLimit: 0
# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
@@ -286,10 +314,10 @@ hooks:
# - optional, default: empty dictionary
# - the only supported action is preload
on_startup:
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
preload:
- "llama"
- "llama"
+46 -22
View File
@@ -20,36 +20,60 @@ if [[ -z "$GITHUB_TOKEN" ]]; then
exit 1
fi
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
# variable, this permits testing with forked llama.cpp repositories
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
# to enable easy container builds on forked repos
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
if [ "$ARCH" == "cpu" ]; then
# cpu only containers just use the latest available
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
echo "Building ${CONTAINER_LATEST} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_LATEST}
fi
# cpu only containers just use the server tag
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r '.[] | select(.metadata.container.tags[] | startswith("server")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
BASE_TAG=server-${LCPP_TAG}
else
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
BASE_TAG=server-${ARCH}-${LCPP_TAG}
fi
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
fi
for CONTAINER_TYPE in non-root root; do
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
USER_UID=0
USER_GID=0
USER_HOME=/root
if [ "$CONTAINER_TYPE" == "non-root" ]; then
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
USER_UID=10001
USER_GID=10001
USER_HOME=/app
fi
echo "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
--build-arg BASE_IMAGE=${BASE_IMAGE} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
done
+31 -7
View File
@@ -1,16 +1,40 @@
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
FROM ${BASE_IMAGE}:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=89
ARG LS_VER=170
ARG LS_REPO=mostlygeek/llama-swap
# Set default UID/GID arguments
ARG UID=10001
ARG GID=10001
ARG USER_HOME=/app
# Add user/group
ENV HOME=$USER_HOME
RUN if [ $UID -ne 0 ]; then \
if [ $GID -ne 0 ]; then \
groupadd --system --gid $GID app; \
fi; \
useradd --system --uid $UID --gid $GID \
--home $USER_HOME app; \
fi
# Handle paths
RUN mkdir --parents $HOME /app
RUN chown --recursive $UID:$GID $HOME /app
# Switch user
USER $UID:$GID
WORKDIR /app
RUN \
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
COPY config.example.yaml /app/config.yaml
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
+386
View File
@@ -0,0 +1,386 @@
# config.yaml
llama-swap is designed to be very simple: one binary, one configuration file.
## minimal viable config
```yaml
models:
model1:
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
```
This is enough to launch `llama-server` to serve `model1`. Of course, llama-swap is about making it possible to serve many models:
```yaml
models:
model1:
cmd: llama-server --port ${PORT} -m /path/to/model.gguf
model2:
cmd: llama-server --port ${PORT} -m /path/to/another_model.gguf
model3:
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
```
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model. Useful if you want to run multiple models at the same time with the `groups` feature.
## Advanced control with `cmd`
llama-swap is also about customizability. You can use any CLI flag available:
```yaml
models:
model1:
cmd: | # support for multi-line
llama-server --PORT ${PORT} -m /path/to/model.gguf
--ctx-size 8192
--jinja
--cache-type-k q8_0
--cache-type-v q8_0
```
## Support for any OpenAI API compatible server
llama-swap supports any OpenAI API compatible server. If you can run it on the CLI llama-swap will be able to manage it. Even if it's run in Docker or Podman containers.
```yaml
models:
"Q3-30B-CODER-VLLM":
name: "Qwen3 30B Coder vllm AWQ (Q3-30B-CODER-VLLM)"
# cmdStop provides a reliable way to stop containers
cmdStop: docker stop vllm-coder
cmd: |
docker run --init --rm --name vllm-coder
--runtime=nvidia --gpus '"device=2,3"'
--shm-size=16g
-v /mnt/nvme/vllm-cache:/root/.cache
-v /mnt/ssd-extra/models:/models -p ${PORT}:8000
vllm/vllm-openai:v0.10.0
--model "/models/cpatonn/Qwen3-Coder-30B-A3B-Instruct-AWQ"
--served-model-name "Q3-30B-CODER-VLLM"
--enable-expert-parallel
--swap-space 16
--max-num-seqs 512
--max-model-len 65536
--max-seq-len-to-capture 65536
--gpu-memory-utilization 0.9
--tensor-parallel-size 2
--trust-remote-code
```
## Many more features..
llama-swap supports many more features to customize how you want to manage your environment.
| Feature | Description |
| --------- | ---------------------------------------------- |
| `ttl` | automatic unloading of models after a timeout |
| `macros` | reusable snippets to use in configurations |
| `groups` | run multiple models at a time |
| `hooks` | event driven functionality |
| `env` | define environment variables per model |
| `aliases` | serve a model with different names |
| `filters` | modify requests before sending to the upstream |
| `...` | And many more tweaks |
## Full Configuration Example
> [!NOTE]
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples.
```yaml
# llama-swap YAML configuration example
# -------------------------------------
#
# 💡 Tip - Use an LLM with this file!
# ====================================
# This example configuration is written to be LLM friendly. Try
# copying this file into an LLM and asking it to explain or generate
# sections for you.
# ====================================
# Usage notes:
# - Below are all the available configuration options for llama-swap.
# - Settings noted as "required" must be in your configuration file
# - Settings noted as "optional" can be omitted
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
# - optional, default: 120
# - minimum value is 15 seconds, anything less will be set to this value
healthCheckTimeout: 500
# logLevel: sets the logging value
# - optional, default: info
# - Valid log levels: debug, info, warn, error
logLevel: info
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
# - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
# - it is automatically incremented for every model that uses it
startPort: 10001
# macros: a dictionary of string substitutions
# - optional, default: empty dictionary
# - macros are reusable snippets
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
# - useful for reducing common configuration settings
# - macro names are strings and must be less than 64 characters
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
# - macro names must not be a reserved name: PORT or MODEL_ID
# - macro values can be numbers, bools, or strings
# - macros can contain other macros, but they must be defined before they are used
macros:
# Example of a multi-line macro
"latest-llama": >
/path/to/llama-server/llama-server-ec9e0301
--port ${PORT}
"default_ctx": 4096
# Example of macro-in-macro usage. macros can contain other macros
# but they must be previously declared.
"default_args": "--ctx-size ${default_ctx}"
# models: a dictionary of model configurations
# - required
# - each key is the model's ID, used in API requests
# - model settings have default values that are used if they are not defined here
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
# - below are examples of the all the settings a model can have
models:
# keys are the model names used in API requests
"llama":
# macros: a dictionary of string substitutions specific to this model
# - optional, default: empty dictionary
# - macros defined here override macros defined in the global macros section
# - model level macros follow the same rules as global macros
macros:
"default_ctx": 16384
"temp": 0.7
# cmd: the command to run to start the inference server.
# - required
# - it is just a string, similar to what you would run on the CLI
# - using `|` allows for comments in the command, these will be parsed out
# - macros can be used within cmd
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model path/to/llama-8B-Q4_K_M.gguf
--ctx-size ${default_ctx}
--temperature ${temp}
# name: a display name for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
name: "llama 3.1 8B"
# description: a description for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
description: "A small but capable model used for quick testing"
# env: define an array of environment variables to inject into cmd's environment
# - optional, default: empty array
# - each value is a single string
# - in the format: ENV_NAME=value
env:
- "CUDA_VISIBLE_DEVICES=0,1,2"
# proxy: the URL where llama-swap routes API requests
# - optional, default: http://localhost:${PORT}
# - if you used ${PORT} in cmd this can be omitted
# - if you use a custom port in cmd this *must* be set
proxy: http://127.0.0.1:8999
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# checkEndpoint: URL path to check if the server is ready
# - optional, default: /health
# - endpoint is expected to return an HTTP 200 response
# - all requests wait until the endpoint is ready or fails
# - use "none" to skip endpoint health checking
checkEndpoint: /custom-endpoint
# ttl: automatically unload the model after ttl seconds
# - optional, default: 0
# - ttl values must be a value greater than 0
# - a value of 0 disables automatic unloading of the model
ttl: 60
# useModelName: override the model name that is sent to upstream server
# - optional, default: ""
# - useful for when the upstream server expects a specific model name that
# is different from the model's ID
useModelName: "qwen:qwq"
# filters: a dictionary of filter settings
# - optional, default: empty dictionary
# - only stripParams is currently supported
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for server side enforcement of sampling parameters
# - the `model` parameter can never be removed
# - can be any JSON key in the request body
# - recommended to stick to sampling parameters
stripParams: "temperature, top_p, top_k"
# metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple
# - metadata is only passed through in /v1/models responses
metadata:
# port will remain an integer
port: ${PORT}
# the ${temp} macro will remain a float
temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
a_list:
- 1
- 1.23
- "macros are OK in list and dictionary types: ${MODEL_ID}"
an_obj:
a: "1"
b: 2
# objects can contain complex types with macro substitution
# becomes: c: [0.7, false, "model: llama"]
c: ["${temp}", false, "model: ${MODEL_ID}"]
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
# - optional, default: 0
# - useful for limiting the number of active parallel requests a model can process
# - must be set per model
# - any number greater than 0 will override the internal default value of 10
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
# - recommended to be omitted and the default used
concurrencyLimit: 0
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
# - optional, default: false
# - unlisted models do not show up in /v1/models api requests
# - can be requested as normal through all apis
unlisted: true
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker example:
# container runtimes like Docker and Podman can be used reliably with
# a combination of cmd, cmdStop, and ${MODEL_ID}
"docker-llama":
proxy: "http://127.0.0.1:${PORT}"
cmd: |
docker run --name ${MODEL_ID}
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggml-org/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# cmdStop: command to run to stop the model gracefully
# - optional, default: ""
# - useful for stopping commands managed by another system
# - the upstream's process id is available in the ${PID} macro
#
# When empty, llama-swap has this default behaviour:
# - on POSIX systems: a SIGTERM signal is sent
# - on Windows, calls taskkill to stop the process
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID}
# groups: a dictionary of group settings
# - optional, default: empty dictionary
# - provides advanced controls over model swapping behaviour
# - using groups some models can be kept loaded indefinitely, while others are swapped out
# - model IDs must be defined in the Models section
# - a model can only be a member of one group
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
# - see issue #109 for details
#
# NOTE: the example below uses model names that are not defined above for demonstration purposes
groups:
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap: controls the model swapping behaviour in within the group
# - optional, default: true
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# exclusive: controls how the group affects other groups
# - optional, default: true
# - true: causes all other groups to unload when this group runs a model
# - false: does not affect other groups
exclusive: true
# members references the models defined above
# required
members:
- "llama"
- "qwen-unlisted"
# Example:
# - in group2 all models can run at the same time
# - when a different group is loaded it causes all running models in this group to unload
"group2":
swap: false
# exclusive: false does not unload other groups when a model in group2 is requested
# - the models in group2 will be loaded but will not unload any other groups
exclusive: false
members:
- "docker-llama"
- "modelA"
- "modelB"
# Example:
# - a persistent group, prevents other groups from unloading it
"forever":
# persistent: prevents over groups from unloading the models in this group
# - optional, default: false
# - does not affect individual model behaviour
persistent: true
# set swap/exclusive to false to prevent swapping inside the group
# and the unloading of other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
# hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary
# - the only supported hook is on_startup
hooks:
# on_startup: a dictionary of actions to perform on startup
# - optional, default: empty dictionary
# - the only supported action is preload
on_startup:
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
preload:
- "llama"
```
+5 -5
View File
@@ -1,6 +1,6 @@
module github.com/mostlygeek/llama-swap
go 1.23.0
go 1.25.4
require (
github.com/billziss-gh/golib v0.2.0
@@ -37,9 +37,9 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+8 -8
View File
@@ -80,16 +80,16 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
+35 -5
View File
@@ -28,7 +28,9 @@ var (
func main() {
// Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port")
listenStr := flag.String("listen", "", "listen ip/port")
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
keyFile := flag.String("tls-key-file", "", "TLS key file")
showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
@@ -55,6 +57,23 @@ func main() {
gin.SetMode(gin.ReleaseMode)
}
// Validate TLS flags.
var useTLS = (*certFile != "" && *keyFile != "")
if (*certFile != "" && *keyFile == "") ||
(*certFile == "" && *keyFile != "") {
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
os.Exit(1)
}
// Set default ports.
if *listenStr == "" {
defaultPort := ":8080"
if useTLS {
defaultPort = ":8443"
}
listenStr = &defaultPort
}
// Setup channels for server management
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1)
@@ -76,7 +95,9 @@ func main() {
fmt.Println("Configuration Changed")
currentPM.Shutdown()
srv.Handler = proxy.New(conf)
newPM := proxy.New(conf)
newPM.SetVersion(date, commit, version)
srv.Handler = newPM
fmt.Println("Configuration Reloaded")
// wait a few seconds and tell any UI to reload
@@ -91,7 +112,9 @@ func main() {
fmt.Printf("Error, unable to load configuration: %v\n", err)
os.Exit(1)
}
srv.Handler = proxy.New(conf)
newPM := proxy.New(conf)
newPM.SetVersion(date, commit, version)
srv.Handler = newPM
}
}
@@ -167,9 +190,16 @@ func main() {
}()
// Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
var err error
if useTLS {
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
err = srv.ListenAndServeTLS(*certFile, *keyFile)
} else {
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
err = srv.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
log.Fatalf("Fatal server error: %v\n", err)
}
}()
Binary file not shown.

Before

Width:  |  Height:  |  Size: 51 KiB

+23
View File
@@ -3,6 +3,7 @@ package config
import (
"fmt"
"io"
"net/url"
"os"
"regexp"
"runtime"
@@ -112,6 +113,7 @@ type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
LogTimeFormat string `yaml:"logTimeFormat"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
@@ -128,6 +130,12 @@ type Config struct {
// hooks, see: #209
Hooks HooksConfig `yaml:"hooks"`
// send loading state in reasoning
SendLoadingState bool `yaml:"sendLoadingState"`
// present aliases to /v1/models OpenAI API listing
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
}
func (c *Config) RealModelName(search string) (string, bool) {
@@ -168,6 +176,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
LogTimeFormat: "",
MetricsMaxInMemory: 1000,
}
err = yaml.Unmarshal(data, &config)
@@ -342,6 +351,20 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
// Validate the proxy URL.
if _, err := url.Parse(modelConfig.Proxy); err != nil {
return Config{}, fmt.Errorf(
"model %s: invalid proxy URL: %w", modelId, err,
)
}
// if sendLoadingState is nil, set it to the global config value
// see #366
if modelConfig.SendLoadingState == nil {
v := config.SendLoadingState // copy it
modelConfig.SendLoadingState = &v
}
config.Models[modelId] = modelConfig
}
+33 -24
View File
@@ -58,6 +58,7 @@ models:
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)
assert.Equal(t, "", config.LogTimeFormat)
// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
@@ -160,9 +161,12 @@ groups:
t.Fatalf("Failed to load config: %v", err)
}
modelLoadingState := false
expected := Config{
LogLevel: "info",
StartPort: 5800,
LogLevel: "info",
LogTimeFormat: "",
StartPort: 5800,
Macros: MacroList{
{"svr-path", "path/to/server"},
},
@@ -171,36 +175,41 @@ groups:
Preload: []string{"model1", "model2"},
},
},
SendLoadingState: false,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
SendLoadingState: &modelLoadingState,
},
"model2": {
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
},
},
HealthCheckTimeout: 15,
+35 -26
View File
@@ -55,6 +55,7 @@ models:
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)
assert.Equal(t, "", config.LogTimeFormat)
// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
@@ -152,44 +153,52 @@ groups:
t.Fatalf("Failed to load config: %v", err)
}
modelLoadingState := false
expected := Config{
LogLevel: "info",
StartPort: 5800,
LogLevel: "info",
LogTimeFormat: "",
StartPort: 5800,
Macros: MacroList{
{"svr-path", "path/to/server"},
},
SendLoadingState: false,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState,
},
"model2": {
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
},
},
HealthCheckTimeout: 15,
+3
View File
@@ -35,6 +35,9 @@ type ModelConfig struct {
// Metadata: see #264
// Arbitrary metadata that can be exposed through the API
Metadata map[string]any `yaml:"metadata"`
// override global setting
SendLoadingState *bool `yaml:"sendLoadingState"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
+21 -1
View File
@@ -50,5 +50,25 @@ models:
}
})
}
}
func TestConfig_ModelSendLoadingState(t *testing.T) {
content := `
sendLoadingState: true
models:
model1:
cmd: path/to/cmd --port ${PORT}
sendLoadingState: false
model2:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.True(t, config.SendLoadingState)
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
assert.False(t, *config.Models["model1"].SendLoadingState)
}
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
assert.True(t, *config.Models["model2"].SendLoadingState)
}
}
+21 -6
View File
@@ -7,6 +7,7 @@ import (
"io"
"os"
"sync"
"time"
"github.com/mostlygeek/llama-swap/event"
)
@@ -32,6 +33,9 @@ type LogMonitor struct {
// logging levels
level LogLevel
prefix string
// timestamps
timeFormat string
}
func NewLogMonitor() *LogMonitor {
@@ -40,11 +44,12 @@ func NewLogMonitor() *LogMonitor {
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{
eventbus: event.NewDispatcherConfig(1000),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
eventbus: event.NewDispatcherConfig(1000),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
timeFormat: "",
}
}
@@ -106,12 +111,22 @@ func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.level = level
}
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
w.mu.Lock()
defer w.mu.Unlock()
w.timeFormat = timeFormat
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
timestamp := ""
if w.timeFormat != "" {
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
}
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
+29
View File
@@ -3,8 +3,10 @@ package proxy
import (
"bytes"
"io"
"strings"
"sync"
"testing"
"time"
)
func TestLogMonitor(t *testing.T) {
@@ -84,3 +86,30 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
t.Errorf("Expected history to be %q, got %q", expected, history)
}
}
func TestWrite_LogTimeFormat(t *testing.T) {
// Create a new LogMonitor instance
lm := NewLogMonitorWriter(io.Discard)
// Enable timestamps
lm.timeFormat = time.RFC3339
// Write the message to the LogMonitor
lm.Info("Hello, World!")
// Get the history from the LogMonitor
history := lm.GetHistory()
timestamp := ""
fields := strings.Fields(string(history))
if len(fields) > 0 {
timestamp = fields[0]
} else {
t.Fatalf("Cannot extract string from history")
}
_, err := time.Parse(time.RFC3339, timestamp)
if err != nil {
t.Fatalf("Cannot find timestamp: %v", err)
}
}
-184
View File
@@ -1,184 +0,0 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
// isStreaming bool
startTime time.Time
}
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
c.Abort()
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
c.Abort()
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
c.Abort()
return
}
writer := &MetricsResponseWriter{
ResponseWriter: c.Writer,
metricsRecorder: &MetricsRecorder{
metricsMonitor: pm.metricsMonitor,
realModelName: realModelName,
startTime: time.Now(),
},
}
c.Writer = writer
c.Next()
// check for streaming response
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
writer.metricsRecorder.processStreamingResponse(writer.body)
} else {
writer.metricsRecorder.processNonStreamingResponse(writer.body)
}
}
}
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
usage := jsonData.Get("usage")
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return false
}
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(rec.startTime).Milliseconds())
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
rec.metricsMonitor.addMetrics(TokenMetrics{
Timestamp: time.Now(),
Model: rec.realModelName,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
})
return true
}
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
// Iterate **backwards** through the lines looking for the data payload with
// usage data
lines := bytes.Split(body, []byte("\n"))
for i := len(lines) - 1; i >= 0; i-- {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
return // short circuit if a metric was recorded
}
}
}
}
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
if len(body) == 0 {
return
}
// Parse JSON to extract usage information
if gjson.ValidBytes(body) {
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
}
}
// MetricsResponseWriter captures the entire response for non-streaming
type MetricsResponseWriter struct {
gin.ResponseWriter
body []byte
metricsRecorder *MetricsRecorder
}
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
if err != nil {
return n, err
}
w.body = append(w.body, b...)
return n, nil
}
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *MetricsResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
+222 -15
View File
@@ -1,12 +1,18 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/tidwall/gjson"
)
// TokenMetrics represents parsed token statistics from llama-server logs
@@ -31,21 +37,18 @@ func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
}
// MetricsMonitor parses llama-server output for token statistics
type MetricsMonitor struct {
// metricsMonitor parses llama-server output for token statistics
type metricsMonitor struct {
mu sync.RWMutex
metrics []TokenMetrics
maxMetrics int
nextID int
logger *LogMonitor
}
func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
maxMetrics := config.MetricsMaxInMemory
if maxMetrics <= 0 {
maxMetrics = 1000 // Default fallback
}
mp := &MetricsMonitor{
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
mp := &metricsMonitor{
logger: logger,
maxMetrics: maxMetrics,
}
@@ -53,7 +56,7 @@ func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
}
// addMetrics adds a new metric to the collection and publishes an event
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
mp.mu.Lock()
defer mp.mu.Unlock()
@@ -66,8 +69,8 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
event.Emit(TokenMetricsEvent{Metrics: metric})
}
// GetMetrics returns a copy of the current metrics
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
// getMetrics returns a copy of the current metrics
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
mp.mu.RLock()
defer mp.mu.RUnlock()
@@ -76,9 +79,213 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
return result
}
// GetMetricsJSON returns metrics as JSON
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
// getMetricsJSON returns metrics as JSON
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
}
// wrapHandler wraps the proxy handler to extract token metrics
// if wrapHandler returns an error it is safe to assume that no
// data was sent to the client
func (mp *metricsMonitor) wrapHandler(
modelID string,
writer gin.ResponseWriter,
request *http.Request,
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error {
recorder := newBodyCopier(writer)
if err := next(modelID, recorder, request); err != nil {
return err
}
// after this point we have to assume that data was sent to the client
// and we can only log errors but not send them to clients
if recorder.Status() != http.StatusOK {
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
return nil
}
body := recorder.body.Bytes()
if len(body) == 0 {
mp.logger.Warn("metrics skipped, empty body")
return nil
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
} else {
mp.addMetrics(tm)
}
} else {
if gjson.ValidBytes(body) {
parsed := gjson.ParseBytes(body)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
if usage.Exists() || timings.Exists() {
if tm, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
} else {
mp.addMetrics(tm)
}
}
} else {
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
}
}
return nil
}
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
// Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split.
// Start from the end of the body and scan backwards for newlines
pos := len(body)
for pos > 0 {
// Find the previous newline (or start of body)
lineStart := bytes.LastIndexByte(body[:pos], '\n')
if lineStart == -1 {
lineStart = 0
} else {
lineStart++ // Move past the newline
}
line := bytes.TrimSpace(body[lineStart:pos])
pos = lineStart - 1 // Move position before the newline for next iteration
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
parsed := gjson.ParseBytes(data)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
if usage.Exists() || timings.Exists() {
return parseMetrics(modelID, start, usage, timings)
}
}
}
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
}
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(start).Milliseconds())
if usage.Exists() {
if pt := usage.Get("prompt_tokens"); pt.Exists() {
// v1/chat/completions
inputTokens = int(pt.Int())
} else if it := usage.Get("input_tokens"); it.Exists() {
// v1/messages
inputTokens = int(it.Int())
}
if ct := usage.Get("completion_tokens"); ct.Exists() {
// v1/chat/completions
outputTokens = int(ct.Int())
} else if ot := usage.Get("output_tokens"); ot.Exists() {
outputTokens = int(ot.Int())
}
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
cachedTokens = int(ct.Int())
}
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(timings.Get("prompt_n").Int())
outputTokens = int(timings.Get("predicted_n").Int())
promptPerSecond = timings.Get("prompt_per_second").Float()
tokensPerSecond = timings.Get("predicted_per_second").Float()
durationMs = int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
return TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
}, nil
}
// responseBodyCopier records the response body and writes to the original response writer
// while also capturing it in a buffer for later processing
type responseBodyCopier struct {
gin.ResponseWriter
body *bytes.Buffer
tee io.Writer
start time.Time
}
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
bodyBuffer := &bytes.Buffer{}
return &responseBodyCopier{
ResponseWriter: w,
body: bodyBuffer,
tee: io.MultiWriter(w, bodyBuffer),
}
}
func (w *responseBodyCopier) Write(b []byte) (int, error) {
if w.start.IsZero() {
w.start = time.Now()
}
// Single write operation that writes to both the response and buffer
return w.tee.Write(b)
}
func (w *responseBodyCopier) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *responseBodyCopier) Header() http.Header {
return w.ResponseWriter.Header()
}
func (w *responseBodyCopier) StartTime() time.Time {
return w.start
}
+693
View File
@@ -0,0 +1,693 @@
package proxy
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/stretchr/testify/assert"
)
func TestMetricsMonitor_AddMetrics(t *testing.T) {
t.Run("adds metrics and assigns ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
}
mm.addMetrics(metric)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 0, metrics[0].ID)
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("increments ID for each metric", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{Model: "model"})
}
metrics := mm.getMetrics()
assert.Equal(t, 5, len(metrics))
for i := 0; i < 5; i++ {
assert.Equal(t, i, metrics[i].ID)
}
})
t.Run("respects max metrics limit", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 3)
// Add 5 metrics
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{
Model: "model",
InputTokens: i,
})
}
metrics := mm.getMetrics()
assert.Equal(t, 3, len(metrics))
// Should keep the last 3 metrics (IDs 2, 3, 4)
assert.Equal(t, 2, metrics[0].ID)
assert.Equal(t, 3, metrics[1].ID)
assert.Equal(t, 4, metrics[2].ID)
})
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
receivedEvent := make(chan TokenMetricsEvent, 1)
cancel := event.On(func(e TokenMetricsEvent) {
receivedEvent <- e
})
defer cancel()
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
}
mm.addMetrics(metric)
select {
case evt := <-receivedEvent:
assert.Equal(t, 0, evt.Metrics.ID)
assert.Equal(t, "test-model", evt.Metrics.Model)
assert.Equal(t, 100, evt.Metrics.InputTokens)
assert.Equal(t, 50, evt.Metrics.OutputTokens)
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for event")
}
})
}
func TestMetricsMonitor_GetMetrics(t *testing.T) {
t.Run("returns empty slice when no metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
metrics := mm.getMetrics()
assert.NotNil(t, metrics)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns copy of metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm.addMetrics(TokenMetrics{Model: "model1"})
mm.addMetrics(TokenMetrics{Model: "model2"})
metrics1 := mm.getMetrics()
metrics2 := mm.getMetrics()
// Verify we got copies
assert.Equal(t, 2, len(metrics1))
assert.Equal(t, 2, len(metrics2))
// Modify the returned slice shouldn't affect the original
metrics1[0].Model = "modified"
metrics3 := mm.getMetrics()
assert.Equal(t, "model1", metrics3[0].Model)
})
}
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
assert.NotNil(t, jsonData)
var metrics []TokenMetrics
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns valid JSON with metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm.addMetrics(TokenMetrics{
Model: "model1",
InputTokens: 100,
OutputTokens: 50,
TokensPerSecond: 25.5,
})
mm.addMetrics(TokenMetrics{
Model: "model2",
InputTokens: 200,
OutputTokens: 100,
TokensPerSecond: 30.0,
})
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
var metrics []TokenMetrics
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 2, len(metrics))
assert.Equal(t, "model1", metrics[0].Model)
assert.Equal(t, "model2", metrics[1].Model)
})
}
func TestMetricsMonitor_WrapHandler(t *testing.T) {
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("successful request with timings data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0,
"cache_n": 20
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 20, metrics[0].CachedTokens)
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
})
t.Run("streaming request with SSE format", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Note: SSE format requires proper line breaks - each data line followed by blank line
responseBody := `data: {"choices":[{"text":"Hello"}]}
data: {"choices":[{"text":" World"}]}
data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}}
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
// When timings data is present, it takes precedence
assert.Equal(t, 10, metrics[0].InputTokens)
assert.Equal(t, 20, metrics[0].OutputTokens)
})
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error"))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("empty response body does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusOK)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("not valid json"))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("next handler error is propagated", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
expectedErr := assert.AnError
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
return expectedErr
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.Equal(t, expectedErr, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"result": "ok"}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
}
func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
t.Run("captures response body", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
testData := []byte("test response body")
n, err := copier.Write(testData)
assert.NoError(t, err)
assert.Equal(t, len(testData), n)
assert.Equal(t, testData, copier.body.Bytes())
assert.Equal(t, string(testData), rec.Body.String())
})
t.Run("sets start time on first write", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
assert.True(t, copier.StartTime().IsZero())
copier.Write([]byte("test"))
assert.False(t, copier.StartTime().IsZero())
})
t.Run("preserves headers", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
copier.Header().Set("X-Test", "value")
assert.Equal(t, "value", rec.Header().Get("X-Test"))
})
t.Run("preserves status code", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
copier.WriteHeader(http.StatusCreated)
// Gin's ResponseWriter tracks status internally
assert.Equal(t, http.StatusCreated, copier.Status())
})
}
func TestMetricsMonitor_Concurrent(t *testing.T) {
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 1000)
var wg sync.WaitGroup
numGoroutines := 10
metricsPerGoroutine := 100
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < metricsPerGoroutine; j++ {
mm.addMetrics(TokenMetrics{
Model: "test-model",
InputTokens: id*1000 + j,
OutputTokens: j,
})
}
}(i)
}
wg.Wait()
metrics := mm.getMetrics()
assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics))
})
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 100)
done := make(chan bool)
// Writer goroutine
go func() {
for i := 0; i < 50; i++ {
mm.addMetrics(TokenMetrics{Model: "test-model"})
time.Sleep(1 * time.Millisecond)
}
done <- true
}()
// Multiple reader goroutines
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 20; j++ {
_ = mm.getMetrics()
_, _ = mm.getMetricsJSON()
time.Sleep(2 * time.Millisecond)
}
}()
}
<-done
wg.Wait()
// Final check
metrics := mm.getMetrics()
assert.Equal(t, 50, len(metrics))
})
}
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
t.Run("prefers timings over usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Timings should take precedence over usage
responseBody := `{
"usage": {
"prompt_tokens": 50,
"completion_tokens": 25
},
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
// Should use timings values, not usage values
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles missing cache_n in timings", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
})
}
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Metrics should be found in the last data line before [DONE]
responseBody := `data: {"choices":[{"text":"First"}]}
data: {"choices":[{"text":"Second"}]}
data: {"usage":{"prompt_tokens":100,"completion_tokens":50}}
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `data: not json
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("handles empty streaming response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := ``
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
// Empty body should not trigger WrapHandler processing
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
}
// Benchmark tests
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
mm := newMetricsMonitor(testLogger, 1000)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
}
}
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
// Test performance with a smaller buffer where wrapping occurs more frequently
mm := newMetricsMonitor(testLogger, 100)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
}
}
+370 -66
View File
@@ -2,16 +2,18 @@ package proxy
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os/exec"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -39,11 +41,13 @@ const (
)
type Process struct {
ID string
config config.ModelConfig
cmd *exec.Cmd
ID string
config config.ModelConfig
cmd *exec.Cmd
reverseProxy *httputil.ReverseProxy
// PR #155 called to cancel the upstream process
cmdMutex sync.RWMutex
cancelUpstream context.CancelFunc
// closed when command exits
@@ -55,12 +59,14 @@ type Process struct {
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time
lastRequestHandledMutex sync.RWMutex
lastRequestHandled time.Time
stateMutex sync.RWMutex
state ProcessState
inFlightRequests sync.WaitGroup
inFlightRequests sync.WaitGroup
inFlightRequestsCount atomic.Int32
// used to block on multiple start() calls
waitStarting sync.WaitGroup
@@ -81,10 +87,29 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
concurrentLimit = config.ConcurrencyLimit
}
// Setup the reverse proxy.
proxyURL, err := url.Parse(config.Proxy)
if err != nil {
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
}
var reverseProxy *httputil.ReverseProxy
if proxyURL != nil {
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
reverseProxy.ModifyResponse = func(resp *http.Response) error {
// prevent nginx from buffering streaming responses (e.g., SSE)
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
}
return &Process{
ID: ID,
config: config,
cmd: nil,
reverseProxy: reverseProxy,
cancelUpstream: nil,
processLogger: processLogger,
proxyLogger: proxyLogger,
@@ -107,6 +132,20 @@ func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
func (p *Process) setLastRequestHandled(t time.Time) {
p.lastRequestHandledMutex.Lock()
defer p.lastRequestHandledMutex.Unlock()
p.lastRequestHandled = t
}
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
func (p *Process) getLastRequestHandled() time.Time {
p.lastRequestHandledMutex.RLock()
defer p.lastRequestHandledMutex.RUnlock()
return p.lastRequestHandled
}
// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
@@ -130,6 +169,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
}
p.state = newState
// Atomically increment waitStarting when entering StateStarting
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
if newState == StateStarting {
p.waitStarting.Add(1)
}
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
return p.state, nil
@@ -158,6 +204,15 @@ func (p *Process) CurrentState() ProcessState {
return p.state
}
// forceState forces the process state to the new state with mutex protection.
// This should only be used in exceptional cases where the normal state transition
// validation via swapState() cannot be used.
func (p *Process) forceState(newState ProcessState) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.state = newState
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
@@ -191,7 +246,7 @@ func (p *Process) start() error {
}
}
p.waitStarting.Add(1)
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
defer p.waitStarting.Done()
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
@@ -201,8 +256,12 @@ func (p *Process) start() error {
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
p.cmd.Cancel = p.cmdStopUpstreamProcess
p.cmd.WaitDelay = p.gracefulStopTimeout
setProcAttributes(p.cmd)
p.cmdMutex.Lock()
p.cancelUpstream = ctxCancelUpstream
p.cmdWaitChan = make(chan struct{})
p.cmdMutex.Unlock()
p.failedStartCount++ // this will be reset to zero when the process has successfully started
@@ -212,7 +271,7 @@ func (p *Process) start() error {
// Set process state to failed
if err != nil {
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
p.state = StateStopped // force it into a stopped state
p.forceState(StateStopped) // force it into a stopped state
return fmt.Errorf(
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
strings.Join(args, " "), err, curState, swapErr,
@@ -285,10 +344,12 @@ func (p *Process) start() error {
return
}
// wait for all inflight requests to complete and ticker
p.inFlightRequests.Wait()
// skip the TTL check if there are inflight requests
if p.inFlightRequestsCount.Load() != 0 {
continue
}
if time.Since(p.lastRequestHandled) > maxDuration {
if time.Since(p.getLastRequestHandled()) > maxDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
p.Stop()
return
@@ -344,7 +405,7 @@ func (p *Process) Shutdown() {
p.stopCommand()
// just force it to this state since there is no recovery from shutdown
p.state = StateShutdown
p.forceState(StateShutdown)
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -355,13 +416,18 @@ func (p *Process) stopCommand() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
}()
if p.cancelUpstream == nil {
p.cmdMutex.RLock()
cancelUpstream := p.cancelUpstream
cmdWaitChan := p.cmdWaitChan
p.cmdMutex.RUnlock()
if cancelUpstream == nil {
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
return
}
p.cancelUpstream()
<-p.cmdWaitChan
cancelUpstream()
<-cmdWaitChan
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
@@ -399,6 +465,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if p.reverseProxy == nil {
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
return
}
requestBeginTime := time.Now()
var startDuration time.Duration
@@ -418,72 +490,75 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
}
p.inFlightRequests.Add(1)
p.inFlightRequestsCount.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.setLastRequestHandled(time.Now())
p.inFlightRequestsCount.Add(-1)
p.inFlightRequests.Done()
}()
// for #366
// - extract streaming param from request context, should have been set by proxymanager
var srw *statusResponseWriter
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
// start the process on demand
if p.CurrentState() != StateReady {
// start a goroutine to stream loading status messages into the response writer
// add a sync so the streaming client only runs when the goroutine has exited
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
// PR #417 (no support for anthropic v1/messages yet)
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
srw = newStatusResponseWriter(p, w)
go srw.statusUpdates(swapCtx)
} else {
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
}
beginStartTime := time.Now()
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway)
cancelLoadCtx()
if srw != nil {
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
// before closing the connection. Without this, the connection would close before
// the goroutine can write its cleanup messages, causing incomplete SSE output.
srw.waitForCompletion(100 * time.Millisecond)
} else {
http.Error(w, errstr, http.StatusBadGateway)
}
return
}
startDuration = time.Since(beginStartTime)
}
proxyTo := p.config.Proxy
client := &http.Client{}
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header.Clone()
// should trigger srw to stop sending loading events ...
cancelLoadCtx()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
// prevent nginx from buffering streaming responses (e.g., SSE)
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
w.Header().Set("X-Accel-Buffering", "no")
}
w.WriteHeader(resp.StatusCode)
// faster than io.Copy when streaming
buf := make([]byte, 32*1024)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
return
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
// recover from http.ErrAbortHandler panics that can occur when the client
// disconnects before the response is sent
defer func() {
if r := recover(); r != nil {
if r == http.ErrAbortHandler {
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
} else {
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
}
}
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}()
if srw != nil {
// Wait for the goroutine to finish writing its final messages
const completionTimeout = 1 * time.Second
if !srw.waitForCompletion(completionTimeout) {
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
}
p.reverseProxy.ServeHTTP(srw, r)
} else {
p.reverseProxy.ServeHTTP(w, r)
}
totalTime := time.Since(requestBeginTime)
@@ -519,13 +594,16 @@ func (p *Process) waitForCmd() {
case StateStopping:
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
p.state = StateStopped
p.forceState(StateStopped)
}
default:
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
p.state = StateStopped // force it to be in this state
p.forceState(StateStopped) // force it to be in this state
}
p.cmdMutex.Lock()
close(p.cmdWaitChan)
p.cmdMutex.Unlock()
}
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
@@ -551,6 +629,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Stdout = p.processLogger
stopCmd.Stderr = p.processLogger
setProcAttributes(stopCmd)
stopCmd.Env = p.cmd.Env
if err := stopCmd.Run(); err != nil {
@@ -566,3 +645,228 @@ func (p *Process) cmdStopUpstreamProcess() error {
return nil
}
var loadingRemarks = []string{
"Still faster than your last standup meeting...",
"Reticulating splines...",
"Waking up the hamsters...",
"Teaching the model manners...",
"Convincing the GPU to participate...",
"Loading weights (they're heavy)...",
"Herding electrons...",
"Compiling excuses for the delay...",
"Downloading more RAM...",
"Asking the model nicely to boot up...",
"Bribing CUDA with cookies...",
"Still loading (blame VRAM)...",
"The model is fashionably late...",
"Warming up those tensors...",
"Making the neural net do push-ups...",
"Your patience is appreciated (really)...",
"Almost there (probably)...",
"Loading like it's 1999...",
"The model forgot where it put its keys...",
"Quantum tunneling through layers...",
"Negotiating with the PCIe bus...",
"Defrosting frozen parameters...",
"Teaching attention heads to focus...",
"Running the matrix (slowly)...",
"Untangling transformer blocks...",
"Calibrating the flux capacitor...",
"Spinning up the probability wheels...",
"Waiting for the GPU to wake from its nap...",
"Converting caffeine to compute...",
"Allocating virtual patience...",
"Performing arcane CUDA rituals...",
"The model is stuck in traffic...",
"Inflating embeddings...",
"Summoning computational demons...",
"Pleading with the OOM killer...",
"Calculating the meaning of life (still at 42)...",
"Training the training wheels...",
"Optimizing the optimizer...",
"Bootstrapping the bootstrapper...",
"Loading loading screen...",
"Processing processing logs...",
"Buffering buffer overflow jokes...",
"The model hit snooze...",
"Debugging the debugger...",
"Compiling the compiler...",
"Parsing the parser (meta)...",
"Tokenizing tokens...",
"Encoding the encoder...",
"Hashing hash browns...",
"Forking spoons (not forks)...",
"The model is contemplating existence...",
"Transcending dimensional barriers...",
"Invoking elder tensor gods...",
"Unfurling probability clouds...",
"Synchronizing parallel universes...",
"The GPU is having second thoughts...",
"Recalibrating reality matrices...",
"Time is an illusion, loading doubly so...",
"Convincing bits to flip themselves...",
"The model is reading its own documentation...",
}
type statusResponseWriter struct {
hasWritten bool
writer http.ResponseWriter
process *Process
wg sync.WaitGroup // Track goroutine completion
start time.Time
}
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
s := &statusResponseWriter{
writer: w,
process: p,
start: time.Now(),
}
s.Header().Set("Content-Type", "text/event-stream") // SSE
s.Header().Set("Cache-Control", "no-cache") // no-cache
s.Header().Set("Connection", "keep-alive") // keep-alive
s.WriteHeader(http.StatusOK) // send status code 200
s.sendLine("━━━━━")
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
return s
}
// statusUpdates sends status updates to the client while the model is loading
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
s.wg.Add(1)
defer s.wg.Done()
// Recover from panics caused by client disconnection
// Note: recover() only works within the same goroutine, so we need it here
defer func() {
if r := recover(); r != nil {
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
}
}()
defer func() {
duration := time.Since(s.start)
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
s.sendLine("━━━━━")
s.sendLine(" ")
}()
// Create a shuffled copy of loadingRemarks
remarks := make([]string, len(loadingRemarks))
copy(remarks, loadingRemarks)
rand.Shuffle(len(remarks), func(i, j int) {
remarks[i], remarks[j] = remarks[j], remarks[i]
})
ri := 0
// Pick a random duration to send a remark
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
lastRemarkTime := time.Now()
ticker := time.NewTicker(time.Second)
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if s.process.CurrentState() == StateReady {
return
}
// Check if it's time for a snarky remark
if time.Since(lastRemarkTime) >= nextRemarkIn {
remark := remarks[ri%len(remarks)]
ri++
s.sendLine(fmt.Sprintf("\n%s", remark))
lastRemarkTime = time.Now()
// Pick a new random duration for the next remark
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else {
s.sendData(".")
}
}
}
}
// waitForCompletion waits for the statusUpdates goroutine to finish
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
return true
case <-time.After(timeout):
return false
}
}
func (s *statusResponseWriter) sendLine(line string) {
s.sendData(line + "\n")
}
func (s *statusResponseWriter) sendData(data string) {
// Create the proper SSE JSON structure
type Delta struct {
ReasoningContent string `json:"reasoning_content"`
}
type Choice struct {
Delta Delta `json:"delta"`
}
type SSEMessage struct {
Choices []Choice `json:"choices"`
}
msg := SSEMessage{
Choices: []Choice{
{
Delta: Delta{
ReasoningContent: data,
},
},
},
}
jsonData, err := json.Marshal(msg)
if err != nil {
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
return
}
// Write SSE formatted data, panic if not able to write
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
if err != nil {
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
}
s.Flush()
}
func (s *statusResponseWriter) Header() http.Header {
return s.writer.Header()
}
func (s *statusResponseWriter) Write(data []byte) (int, error) {
return s.writer.Write(data)
}
func (s *statusResponseWriter) WriteHeader(statusCode int) {
if s.hasWritten {
return
}
s.hasWritten = true
s.writer.WriteHeader(statusCode)
s.Flush()
}
// Add Flush method
func (s *statusResponseWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
+74 -1
View File
@@ -436,7 +436,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
if runtime.GOOS == "windows" {
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
} else {
assert.Contains(t, w.Body.String(), "unexpected EOF")
// Upstream may be killed mid-response.
// Assert an incomplete or partial response.
assert.NotEqual(t, "12345", w.Body.String())
}
close(waitChan)
@@ -492,3 +494,74 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
}
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
// handled appropriately.
//
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
// can't copy the body. This can be caused by a client disconnecting before the full
// response is sent from some reason.
//
// bug: https://github.com/mostlygeek/llama-swap/issues/362
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
defer func() {
if r := recover(); r != nil {
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
}
}()
expectedMessage := "panic_test"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
defer process.Stop()
// Start the process
err := process.start()
assert.Nil(t, err)
assert.Equal(t, StateReady, process.CurrentState())
// Create a custom ResponseWriter that simulates a client disconnect
// by panicking when Write is called after headers are sent
panicWriter := &panicOnWriteResponseWriter{
ResponseRecorder: httptest.NewRecorder(),
shouldPanic: true,
}
// Make a request that will trigger the panic
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
// ProxyRequest should catch and handle this panic gracefully.
process.ProxyRequest(panicWriter, req)
// If we get here, the panic was properly recovered in ProxyRequest
// The process should still be in a ready state
assert.Equal(t, StateReady, process.CurrentState())
}
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
// to simulate a client disconnect after headers are sent
// used by: TestProcess_ReverseProxyPanicIsHandled
type panicOnWriteResponseWriter struct {
*httptest.ResponseRecorder
shouldPanic bool
headerWritten bool
}
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
w.headerWritten = true
w.ResponseRecorder.WriteHeader(statusCode)
}
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
if w.shouldPanic && w.headerWritten {
// Simulate the panic that httputil.ReverseProxy throws
panic(http.ErrAbortHandler)
}
return w.ResponseRecorder.Write(b)
}
+12
View File
@@ -0,0 +1,12 @@
//go:build !windows
package proxy
import (
"os/exec"
)
// setProcAttributes sets platform-specific process attributes
func setProcAttributes(cmd *exec.Cmd) {
// No-op on Unix systems
}
+16
View File
@@ -0,0 +1,16 @@
//go:build windows
package proxy
import (
"os/exec"
"syscall"
)
// setProcAttributes sets platform-specific process attributes
func setProcAttributes(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
HideWindow: true,
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
}
}
+141 -37
View File
@@ -25,6 +25,8 @@ const (
PROFILE_SPLIT_CHAR = ":"
)
type proxyCtxKey string
type ProxyManager struct {
sync.Mutex
@@ -36,13 +38,18 @@ type ProxyManager struct {
upstreamLogger *LogMonitor
muxLogger *LogMonitor
metricsMonitor *MetricsMonitor
metricsMonitor *metricsMonitor
processGroups map[string]*ProcessGroup
// shutdown signaling
shutdownCtx context.Context
shutdownCancel context.CancelFunc
// version info
buildDate string
commit string
version string
}
func New(config config.Config) *ProxyManager {
@@ -73,8 +80,39 @@ func New(config config.Config) *ProxyManager {
upstreamLogger.SetLogLevel(LevelInfo)
}
// see: https://go.dev/src/time/format.go
timeFormats := map[string]string{
"ansic": time.ANSIC,
"unixdate": time.UnixDate,
"rubydate": time.RubyDate,
"rfc822": time.RFC822,
"rfc822z": time.RFC822Z,
"rfc850": time.RFC850,
"rfc1123": time.RFC1123,
"rfc1123z": time.RFC1123Z,
"rfc3339": time.RFC3339,
"rfc3339nano": time.RFC3339Nano,
"kitchen": time.Kitchen,
"stamp": time.Stamp,
"stampmilli": time.StampMilli,
"stampmicro": time.StampMicro,
"stampnano": time.StampNano,
}
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(config.LogTimeFormat))]; ok {
proxyLogger.SetLogTimeFormat(timeFormat)
upstreamLogger.SetLogTimeFormat(timeFormat)
}
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
var maxMetrics int
if config.MetricsMaxInMemory <= 0 {
maxMetrics = 1000 // Default fallback
} else {
maxMetrics = config.MetricsMaxInMemory
}
pm := &ProxyManager{
config: config,
ginEngine: gin.New(),
@@ -83,12 +121,16 @@ func New(config config.Config) *ProxyManager {
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
metricsMonitor: NewMetricsMonitor(&config),
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
processGroups: make(map[string]*ProcessGroup),
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
buildDate: "unknown",
commit: "abcd1234",
version: "0",
}
// create the process groups
@@ -131,7 +173,15 @@ func New(config config.Config) *ProxyManager {
}
func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.Use(func(c *gin.Context) {
// don't log the Wake on Lan proxy health check
if c.Request.URL.Path == "/wol-health" {
c.Next()
return
}
// Start timer
start := time.Now()
@@ -185,30 +235,30 @@ func (pm *ProxyManager) setupGinEngine() {
c.Next()
})
mm := MetricsMiddleware(pm)
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/chat/completions", pm.proxyInferenceHandler)
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/completions", pm.proxyInferenceHandler)
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
pm.ginEngine.POST("/v1/messages", pm.proxyInferenceHandler)
// Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/embeddings", pm.proxyInferenceHandler)
// llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/reranking", pm.proxyInferenceHandler)
pm.ginEngine.POST("/rerank", pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/reranking", pm.proxyInferenceHandler)
// llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/infill", pm.proxyInferenceHandler)
// llama-server's /completion endpoint
pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/completion", pm.proxyInferenceHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/speech", pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -235,6 +285,11 @@ func (pm *ProxyManager) setupGinEngine() {
c.String(http.StatusOK, "OK")
})
// see cmd/wol-proxy/wol-proxy.go, not logged
pm.ginEngine.GET("/wol-health", func(c *gin.Context) {
c.String(http.StatusOK, "OK")
})
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
c.Data(http.StatusOK, "image/x-icon", data)
@@ -356,28 +411,40 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
continue
}
record := gin.H{
"id": id,
"object": "model",
"created": createdTime,
"owned_by": "llama-swap",
newRecord := func(modelId string) gin.H {
record := gin.H{
"id": modelId,
"object": "model",
"created": createdTime,
"owned_by": "llama-swap",
}
if name := strings.TrimSpace(modelConfig.Name); name != "" {
record["name"] = name
}
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
record["description"] = desc
}
// Add metadata if present
if len(modelConfig.Metadata) > 0 {
record["meta"] = gin.H{
"llamaswap": modelConfig.Metadata,
}
}
return record
}
if name := strings.TrimSpace(modelConfig.Name); name != "" {
record["name"] = name
}
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
record["description"] = desc
}
data = append(data, newRecord(id))
// Add metadata if present
if len(modelConfig.Metadata) > 0 {
record["meta"] = gin.H{
"llamaswap": modelConfig.Metadata,
// Include aliases
if pm.config.IncludeAliasesInList {
for _, alias := range modelConfig.Aliases {
if alias := strings.TrimSpace(alias); alias != "" {
data = append(data, newRecord(alias))
}
}
}
data = append(data, record)
}
// Sort by the "id" key
@@ -461,11 +528,26 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
}
// rewrite the path
originalPath := c.Request.URL.Path
c.Request.URL.Path = remainingPath
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
return
}
} else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
return
}
}
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
@@ -522,10 +604,24 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
}
}
@@ -685,3 +781,11 @@ func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
}
return nil
}
func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) {
pm.Lock()
defer pm.Unlock()
pm.buildDate = buildDate
pm.commit = commit
pm.version = version
}
+11 -2
View File
@@ -28,6 +28,7 @@ func addApiHandlers(pm *ProxyManager) {
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics)
apiGroup.GET("/version", pm.apiGetVersion)
}
}
@@ -180,7 +181,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
sendLogData("proxy", pm.proxyLogger.GetHistory())
sendLogData("upstream", pm.upstreamLogger.GetHistory())
sendModels()
sendMetrics(pm.metricsMonitor.GetMetrics())
sendMetrics(pm.metricsMonitor.getMetrics())
for {
select {
@@ -198,7 +199,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
}
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
jsonData, err := pm.metricsMonitor.getMetricsJSON()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
return
@@ -227,3 +228,11 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
c.String(http.StatusOK, "OK")
}
}
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
c.JSON(http.StatusOK, map[string]string{
"version": pm.version,
"commit": pm.commit,
"build_date": pm.buildDate,
})
}
+170 -102
View File
@@ -21,6 +21,32 @@ import (
"github.com/tidwall/gjson"
)
// TestResponseRecorder adds CloseNotify to httptest.ResponseRecorder.
// "If you want to write your own tests around streams you will need a Recorder that can handle CloseNotifier."
// The tests can panic otherwise:
// panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify
// See: https://github.com/gin-gonic/gin/issues/1815
// TestResponseRecorder is taken from gin's own tests: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
type TestResponseRecorder struct {
*httptest.ResponseRecorder
closeChannel chan bool
}
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}
func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}
func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
make(chan bool, 1),
}
}
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
@@ -37,7 +63,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -74,7 +100,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
t.Run(requestedModel, func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -116,7 +142,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
for _, requestedModel := range tests {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -159,7 +185,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
@@ -212,7 +238,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
// Create a test request
req := httptest.NewRequest("GET", "/v1/models", nil)
req.Header.Add("Origin", "i-am-the-origin")
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
// Call the listModelsHandler
proxy.ServeHTTP(w, req)
@@ -311,7 +337,7 @@ models:
proxy := New(processedConfig)
req := httptest.NewRequest("GET", "/v1/models", nil)
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -387,7 +413,7 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
// Request models list
req := httptest.NewRequest("GET", "/v1/models", nil)
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -411,6 +437,70 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
}
}
func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
// Configure alias
config := config.Config{
HealthCheckTimeout: 15,
IncludeAliasesInList: true,
Models: map[string]config.ModelConfig{
"model1": func() config.ModelConfig {
mc := getTestSimpleResponderConfig("model1")
mc.Name = "Model 1"
mc.Aliases = []string{"alias1"}
return mc
}(),
},
LogLevel: "error",
}
proxy := New(config)
// Request models list
req := httptest.NewRequest("GET", "/v1/models", nil)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err)
}
// We expect both base id and alias
var model1Data, alias1Data map[string]any
for _, model := range response.Data {
if model["id"] == "model1" {
model1Data = model
} else if model["id"] == "alias1" {
alias1Data = model
}
}
// Verify model1 has name
assert.NotNil(t, model1Data)
_, exists := model1Data["name"]
if !assert.True(t, exists, "model1 should have name key") {
t.FailNow()
}
name1, ok := model1Data["name"].(string)
assert.True(t, ok, "name1 should be a string")
// Verify alias1 has name
assert.NotNil(t, alias1Data)
_, exists = alias1Data["name"]
if !assert.True(t, exists, "alias1 should have name key") {
t.FailNow()
}
name2, ok := alias1Data["name"].(string)
assert.True(t, ok, "name2 should be a string")
// Name keys should match
assert.Equal(t, name1, name2)
}
func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
@@ -448,7 +538,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
// send a request to trigger the proxy to load ... this should hang waiting for start up
proxy.ServeHTTP(w, req)
@@ -476,12 +566,12 @@ func TestProxyManager_Unload(t *testing.T) {
proxy := New(conf)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder()
w = CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK")
@@ -519,7 +609,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
}
@@ -527,7 +617,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
if !assert.Equal(t, w.Body.String(), "OK") {
@@ -571,7 +661,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -589,13 +679,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Load just a model.
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder()
w = CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
var response RunningResponse
@@ -647,7 +737,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req)
// Verify the response
@@ -682,7 +772,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -716,7 +806,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req)
// Verify the response
@@ -784,7 +874,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
@@ -812,7 +902,7 @@ models:
defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("main model name", func(t *testing.T) {
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder()
rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
@@ -820,7 +910,7 @@ models:
t.Run("model alias", func(t *testing.T) {
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
rec := httptest.NewRecorder()
rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
@@ -841,7 +931,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -869,7 +959,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -885,76 +975,6 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
// t.Logf("%v", response)
}
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a non-streaming request
reqBody := `{"model":"model1", "stream": false}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a streaming request
reqBody := `{"model":"model1", "stream": true}`
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_HealthEndpoint(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
@@ -967,7 +987,7 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String())
@@ -988,7 +1008,7 @@ func TestProxyManager_CompletionEndpoint(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -1075,18 +1095,28 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
for _, endpoint := range endpoints {
t.Run(endpoint, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
req := httptest.NewRequest("GET", endpoint, nil)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
rec := CreateTestResponseRecorder()
// We don't need the handler to fully complete, just to set the headers
// so run it in a goroutine and check the headers after a short delay
go proxy.ServeHTTP(rec, req)
time.Sleep(10 * time.Millisecond) // give it time to start and write headers
// Run handler in goroutine and wait for context timeout
done := make(chan struct{})
go func() {
defer close(done)
proxy.ServeHTTP(rec, req)
}()
// Wait for either the handler to complete or context to timeout
<-ctx.Done()
// At this point, the handler has either finished or been cancelled
// Wait for the goroutine to fully exit before reading
<-done
// Now it's safe to read from rec - no more concurrent writes
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
})
@@ -1109,7 +1139,7 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin
reqBody := `{"model":"streaming-model"}`
// simple-responder will return text/event-stream when stream=true is in the query
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
rec := httptest.NewRecorder()
rec := CreateTestResponseRecorder()
proxy.ServeHTTP(rec, req)
@@ -1117,3 +1147,41 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
assert.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
}
func TestProxyManager_ApiGetVersion(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
// Version test map
versionTest := map[string]string{
"build_date": "1970-01-01T00:00:00Z",
"commit": "cc915ddb6f04a42d9cd1f524e1d46ec6ed069fdc",
"version": "v001",
}
proxy := New(config)
proxy.SetVersion(versionTest["build_date"], versionTest["commit"], versionTest["version"])
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/api/version", nil)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Ensure json response
assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type"))
// Check for attributes
response := map[string]string{}
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
for key, value := range versionTest {
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
}
}
+64 -62
View File
@@ -752,9 +752,9 @@
}
},
"node_modules/@eslint-community/eslint-utils": {
"version": "4.7.0",
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz",
"integrity": "sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw==",
"version": "4.9.0",
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.9.0.tgz",
"integrity": "sha512-ayVFHdtZ+hsq1t2Dy24wCmGXGe4q9Gu3smhLYALJrr473ZH27MsnSL+LKUlimp4BWJqMDMLmPpx/Q9R3OAlL4g==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -794,13 +794,13 @@
}
},
"node_modules/@eslint/config-array": {
"version": "0.20.0",
"resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.20.0.tgz",
"integrity": "sha512-fxlS1kkIjx8+vy2SjuCB94q3htSNrufYTXubwiBFeaQHbH6Ipi43gFJq2zCMt6PHhImH3Xmr0NksKDvchWlpQQ==",
"version": "0.21.1",
"resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.21.1.tgz",
"integrity": "sha512-aw1gNayWpdI/jSYVgzN5pL0cfzU02GT3NBpeT/DXbx1/1x7ZKxFPd9bwrzygx/qiwIQiJ1sw/zD8qY/kRvlGHA==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"@eslint/object-schema": "^2.1.6",
"@eslint/object-schema": "^2.1.7",
"debug": "^4.3.1",
"minimatch": "^3.1.2"
},
@@ -809,19 +809,22 @@
}
},
"node_modules/@eslint/config-helpers": {
"version": "0.2.2",
"resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.2.2.tgz",
"integrity": "sha512-+GPzk8PlG0sPpzdU5ZvIRMPidzAnZDl/s9L+y13iodqvb8leL53bTannOrQ/Im7UkpsmFU5Ily5U60LWixnmLg==",
"version": "0.4.2",
"resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.4.2.tgz",
"integrity": "sha512-gBrxN88gOIf3R7ja5K9slwNayVcZgK6SOUORm2uBzTeIEfeVaIhOpCtTox3P6R7o2jLFwLFTLnC7kU/RGcYEgw==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"@eslint/core": "^0.17.0"
},
"engines": {
"node": "^18.18.0 || ^20.9.0 || >=21.1.0"
}
},
"node_modules/@eslint/core": {
"version": "0.14.0",
"resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.14.0.tgz",
"integrity": "sha512-qIbV0/JZr7iSDjqAc60IqbLdsj9GDt16xQtWD+B78d/HAlvysGdZZ6rpJHGAc2T0FQx1X6thsSPdnoiGKdNtdg==",
"version": "0.17.0",
"resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.17.0.tgz",
"integrity": "sha512-yL/sLrpmtDaFEiUj1osRP4TI2MDz1AddJL+jZ7KSqvBuliN4xqYY54IfdN8qD8Toa6g1iloph1fxQNkjOxrrpQ==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
@@ -869,9 +872,9 @@
}
},
"node_modules/@eslint/js": {
"version": "9.28.0",
"resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.28.0.tgz",
"integrity": "sha512-fnqSjGWd/CoIp4EXIxWVK/sHA6DOHN4+8Ix2cX5ycOY7LG0UY8nHCU5pIp2eaE1Mc7Qd8kHspYNzYXT2ojPLzg==",
"version": "9.39.1",
"resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.39.1.tgz",
"integrity": "sha512-S26Stp4zCy88tH94QbBv3XCuzRQiZ9yXofEILmglYTh/Ug/a9/umqvgFtYBAo3Lp0nsI/5/qH1CCrbdK3AP1Tw==",
"dev": true,
"license": "MIT",
"engines": {
@@ -882,9 +885,9 @@
}
},
"node_modules/@eslint/object-schema": {
"version": "2.1.6",
"resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.6.tgz",
"integrity": "sha512-RBMg5FRL0I0gs51M/guSAj5/e14VQ4tpZnQNWwuDT66P14I43ItmPfIZRhO9fUVIPOAQXU47atlywZ/czoqFPA==",
"version": "2.1.7",
"resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.7.tgz",
"integrity": "sha512-VtAOaymWVfZcmZbp6E2mympDIHvyjXs/12LqWYjVw6qjrfF+VK+fyG33kChz3nnK+SU5/NeHOqrTEHS8sXO3OA==",
"dev": true,
"license": "Apache-2.0",
"engines": {
@@ -892,13 +895,13 @@
}
},
"node_modules/@eslint/plugin-kit": {
"version": "0.3.1",
"resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.3.1.tgz",
"integrity": "sha512-0J+zgWxHN+xXONWIyPWKFMgVuJoZuGiIFu8yxk7RJjxkzpGmyja5wRFqZIVtjDVOQpV+Rw0iOAjYPE2eQyjr0w==",
"version": "0.4.1",
"resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.4.1.tgz",
"integrity": "sha512-43/qtrDUokr7LJqoF2c3+RInu/t4zfrpYdoSDfYyhg52rwLV6TnOvdG4fXm7IkSB3wErkcmJS9iEhjVtOSEjjA==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"@eslint/core": "^0.14.0",
"@eslint/core": "^0.17.0",
"levn": "^0.4.1"
},
"engines": {
@@ -1908,9 +1911,9 @@
}
},
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz",
"integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==",
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -2010,9 +2013,9 @@
}
},
"node_modules/acorn": {
"version": "8.14.1",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.1.tgz",
"integrity": "sha512-OvQ/2pUDKmgfCg++xsTX1wGxfTaszcHVcTctW4UJB4hibJx2HXxxO5UmVgyjMa+ZDsiaf5wWLXYpRWMmBI0QHg==",
"version": "8.15.0",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"dev": true,
"license": "MIT",
"bin": {
@@ -2080,9 +2083,9 @@
"license": "MIT"
},
"node_modules/brace-expansion": {
"version": "1.1.11",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz",
"integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==",
"version": "1.1.12",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -2380,33 +2383,32 @@
}
},
"node_modules/eslint": {
"version": "9.28.0",
"resolved": "https://registry.npmjs.org/eslint/-/eslint-9.28.0.tgz",
"integrity": "sha512-ocgh41VhRlf9+fVpe7QKzwLj9c92fDiqOj8Y3Sd4/ZmVA4Btx4PlUYPq4pp9JDyupkf1upbEXecxL2mwNV7jPQ==",
"version": "9.39.1",
"resolved": "https://registry.npmjs.org/eslint/-/eslint-9.39.1.tgz",
"integrity": "sha512-BhHmn2yNOFA9H9JmmIVKJmd288g9hrVRDkdoIgRCRuSySRUHH7r/DI6aAXW9T1WwUuY3DFgrcaqB+deURBLR5g==",
"dev": true,
"license": "MIT",
"dependencies": {
"@eslint-community/eslint-utils": "^4.2.0",
"@eslint-community/eslint-utils": "^4.8.0",
"@eslint-community/regexpp": "^4.12.1",
"@eslint/config-array": "^0.20.0",
"@eslint/config-helpers": "^0.2.1",
"@eslint/core": "^0.14.0",
"@eslint/config-array": "^0.21.1",
"@eslint/config-helpers": "^0.4.2",
"@eslint/core": "^0.17.0",
"@eslint/eslintrc": "^3.3.1",
"@eslint/js": "9.28.0",
"@eslint/plugin-kit": "^0.3.1",
"@eslint/js": "9.39.1",
"@eslint/plugin-kit": "^0.4.1",
"@humanfs/node": "^0.16.6",
"@humanwhocodes/module-importer": "^1.0.1",
"@humanwhocodes/retry": "^0.4.2",
"@types/estree": "^1.0.6",
"@types/json-schema": "^7.0.15",
"ajv": "^6.12.4",
"chalk": "^4.0.0",
"cross-spawn": "^7.0.6",
"debug": "^4.3.2",
"escape-string-regexp": "^4.0.0",
"eslint-scope": "^8.3.0",
"eslint-visitor-keys": "^4.2.0",
"espree": "^10.3.0",
"eslint-scope": "^8.4.0",
"eslint-visitor-keys": "^4.2.1",
"espree": "^10.4.0",
"esquery": "^1.5.0",
"esutils": "^2.0.2",
"fast-deep-equal": "^3.1.3",
@@ -2464,9 +2466,9 @@
}
},
"node_modules/eslint-scope": {
"version": "8.3.0",
"resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.3.0.tgz",
"integrity": "sha512-pUNxi75F8MJ/GdeKtVLSbYg4ZI34J6C0C7sbL4YOp2exGwen7ZsuBqKzUhXd0qMQ362yET3z+uPwKeg/0C2XCQ==",
"version": "8.4.0",
"resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.4.0.tgz",
"integrity": "sha512-sNXOfKCn74rt8RICKMvJS7XKV/Xk9kA7DyJr8mJik3S7Cwgy3qlkkmyS2uQB3jiJg6VNdZd/pDBJu0nvG2NlTg==",
"dev": true,
"license": "BSD-2-Clause",
"dependencies": {
@@ -2481,9 +2483,9 @@
}
},
"node_modules/eslint-visitor-keys": {
"version": "4.2.0",
"resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.0.tgz",
"integrity": "sha512-UyLnSehNt62FFhSwjZlHmeokpRK59rcz29j+F1/aDgbkbRTk7wIc9XzdoasMUbRNKDM0qQt/+BJ4BrpFeABemw==",
"version": "4.2.1",
"resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.1.tgz",
"integrity": "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==",
"dev": true,
"license": "Apache-2.0",
"engines": {
@@ -2494,15 +2496,15 @@
}
},
"node_modules/espree": {
"version": "10.3.0",
"resolved": "https://registry.npmjs.org/espree/-/espree-10.3.0.tgz",
"integrity": "sha512-0QYC8b24HWY8zjRnDTL6RiHfDbAWn63qb4LMj1Z4b076A4une81+z03Kg7l7mn/48PUTqoLptSXez8oknU8Clg==",
"version": "10.4.0",
"resolved": "https://registry.npmjs.org/espree/-/espree-10.4.0.tgz",
"integrity": "sha512-j6PAQ2uUr79PZhBjP5C5fhl8e39FmRnOjsD5lGnWrFU8i2G776tBK7+nP8KuQUTTyAZUwfQqXAgrVH5MbH9CYQ==",
"dev": true,
"license": "BSD-2-Clause",
"dependencies": {
"acorn": "^8.14.0",
"acorn": "^8.15.0",
"acorn-jsx": "^5.3.2",
"eslint-visitor-keys": "^4.2.0"
"eslint-visitor-keys": "^4.2.1"
},
"engines": {
"node": "^18.18.0 || ^20.9.0 || >=21.1.0"
@@ -2852,9 +2854,9 @@
"license": "MIT"
},
"node_modules/js-yaml": {
"version": "4.1.0",
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz",
"integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==",
"version": "4.1.1",
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz",
"integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -3975,9 +3977,9 @@
}
},
"node_modules/vite": {
"version": "6.3.5",
"resolved": "https://registry.npmjs.org/vite/-/vite-6.3.5.tgz",
"integrity": "sha512-cZn6NDFE7wdTpINgs++ZJ4N49W2vRp8LCKrn3Ob1kYNtOo21vfDoaV5GzBfLU4MovSAB8uNRm4jgzVQZ+mBzPQ==",
"version": "6.4.1",
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"dependencies": {
+2 -2
View File
@@ -2,7 +2,7 @@ import { useAPI } from "../contexts/APIProvider";
import { useMemo } from "react";
const ConnectionStatusIcon = () => {
const { connectionStatus } = useAPI();
const { connectionStatus, versionInfo } = useAPI();
const eventStatusColor = useMemo(() => {
switch (connectionStatus) {
@@ -17,7 +17,7 @@ const ConnectionStatusIcon = () => {
}, [connectionStatus]);
return (
<div className="flex items-center" title={`event stream: ${connectionStatus}`}>
<div className="flex items-center" title={`Event Stream: ${connectionStatus ?? 'unknown'}\nAPI Version: ${versionInfo?.version ?? 'unknown'}\nCommit Hash: ${versionInfo?.commit?.substring(0,7) ?? 'unknown'}\nBuild Date: ${versionInfo?.build_date ?? 'unknown'}`}>
<span className={`inline-block w-3 h-3 rounded-full ${eventStatusColor} mr-2`}></span>
</div>
);
+2 -2
View File
@@ -5,7 +5,7 @@ import { useTheme } from "../contexts/ThemeProvider";
import ConnectionStatusIcon from "./ConnectionStatus";
export function Header() {
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle } = useTheme();
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle, isNarrow } = useTheme();
const handleTitleChange = useCallback(
(newTitle: string) => {
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
@@ -17,7 +17,7 @@ export function Header() {
`text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`;
return (
<header className="flex items-center justify-between bg-surface border-b border-border p-2 px-4 h-[75px]">
<header className={`flex items-center justify-between bg-surface border-b border-border px-4 ${isNarrow ? "py-1 h-[60px]" : "p-2 h-[75px]"}`}>
{screenWidth !== "xs" && screenWidth !== "sm" && (
<h1
contentEditable
+35 -1
View File
@@ -23,6 +23,7 @@ interface APIProviderType {
upstreamLogs: string;
metrics: Metrics[];
connectionStatus: ConnectionState;
versionInfo: VersionInfo;
}
interface Metrics {
@@ -41,11 +42,18 @@ interface LogData {
source: "upstream" | "proxy";
data: string;
}
interface APIEventEnvelope {
type: "modelStatus" | "logData" | "metrics";
data: string;
}
interface VersionInfo {
build_date: string;
commit: string;
version: string;
}
const APIContext = createContext<APIProviderType | undefined>(undefined);
type APIProviderProps = {
children: ReactNode;
@@ -59,6 +67,11 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
const [upstreamLogs, setUpstreamLogs] = useState("");
const [metrics, setMetrics] = useState<Metrics[]>([]);
const [connectionStatus, setConnectionState] = useState<ConnectionState>("disconnected");
const [versionInfo, setVersionInfo] = useState<VersionInfo>({
build_date: "unknown",
commit: "unknown",
version: "unknown"
});
//const apiEventSource = useRef<EventSource | null>(null);
const [models, setModels] = useState<Model[]>([]);
@@ -152,6 +165,26 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
connect();
}, []);
useEffect(() => {
// fetch version
const fetchVersion = async () => {
try {
const response = await fetch("/api/version");
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const data: VersionInfo = await response.json();
setVersionInfo(data);
} catch (error) {
console.error(error);
}
};
if (connectionStatus === 'connected') {
fetchVersion();
}
}, [connectionStatus]);
useEffect(() => {
if (autoStartAPIEvents) {
enableAPIEvents(true);
@@ -230,8 +263,9 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
upstreamLogs,
metrics,
connectionStatus,
versionInfo,
}),
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
[models, listModels, unloadAllModels, unloadSingleModel, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics, connectionStatus, versionInfo]
);
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
+346 -42
View File
@@ -4,7 +4,7 @@ import { LogPanel } from "./LogViewer";
import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import { useTheme } from "../contexts/ThemeProvider";
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine } from "react-icons/ri";
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine, RiMenuFill } from "react-icons/ri";
export default function ModelsPage() {
const { isNarrow } = useTheme();
@@ -38,9 +38,11 @@ export default function ModelsPage() {
function ModelsPanel() {
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
const { isNarrow } = useTheme();
const [isUnloading, setIsUnloading] = useState(false);
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
const [menuOpen, setMenuOpen] = useState(false);
const filteredModels = useMemo(() => {
return models.filter((model) => showUnlisted || !model.unlisted);
@@ -66,33 +68,77 @@ function ModelsPanel() {
return (
<div className="card h-full flex flex-col">
<div className="shrink-0">
<h2>Models</h2>
<div className="flex justify-between">
<div className="flex gap-2">
<button
className="btn text-base flex items-center gap-2"
onClick={toggleIdorName}
style={{ lineHeight: "1.2" }}
>
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"}
</button>
<div className="flex justify-between items-baseline">
<h2 className={isNarrow ? "text-xl" : ""}>Models</h2>
{isNarrow && (
<div className="relative">
<button className="btn text-base flex items-center gap-2 py-1" onClick={() => setMenuOpen(!menuOpen)}>
<RiMenuFill size="20" />
</button>
{menuOpen && (
<div className="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
<button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
toggleIdorName();
setMenuOpen(false);
}}
>
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "Show Name" : "Show ID"}
</button>
<button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
setShowUnlisted(!showUnlisted);
setMenuOpen(false);
}}
>
{showUnlisted ? <RiEyeOffFill size="20" /> : <RiEyeFill size="20" />}{" "}
{showUnlisted ? "Hide Unlisted" : "Show Unlisted"}
</button>
<button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
handleUnloadAllModels();
setMenuOpen(false);
}}
disabled={isUnloading}
>
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
)}
</div>
)}
</div>
{!isNarrow && (
<div className="flex justify-between">
<div className="flex gap-2">
<button
className="btn text-base flex items-center gap-2"
onClick={toggleIdorName}
style={{ lineHeight: "1.2" }}
>
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"}
</button>
<button
className="btn text-base flex items-center gap-2"
onClick={() => setShowUnlisted(!showUnlisted)}
style={{ lineHeight: "1.2" }}
>
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted
</button>
</div>
<button
className="btn text-base flex items-center gap-2"
onClick={() => setShowUnlisted(!showUnlisted)}
style={{ lineHeight: "1.2" }}
onClick={handleUnloadAllModels}
disabled={isUnloading}
>
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
<button
className="btn text-base flex items-center gap-2"
onClick={handleUnloadAllModels}
disabled={isUnloading}
>
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
)}
</div>
<div className="flex-1 overflow-y-auto">
@@ -145,42 +191,300 @@ function ModelsPanel() {
);
}
interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
function TokenHistogram({ data }: { data: HistogramData }) {
const { bins, min, max, p50, p95, p99 } = data;
const maxCount = Math.max(...bins);
const height = 120;
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
// Use viewBox for responsive sizing
const viewBoxWidth = 600;
const chartWidth = viewBoxWidth - padding.left - padding.right;
const chartHeight = height - padding.top - padding.bottom;
const barWidth = chartWidth / bins.length;
const range = max - min;
// Calculate x position for a given value
const getXPosition = (value: number) => {
return padding.left + ((value - min) / range) * chartWidth;
};
return (
<div className="mt-2 w-full">
<svg
viewBox={`0 0 ${viewBoxWidth} ${height}`}
className="w-full h-auto"
preserveAspectRatio="xMidYMid meet"
>
{/* Y-axis */}
<line
x1={padding.left}
y1={padding.top}
x2={padding.left}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="1"
opacity="0.3"
/>
{/* X-axis */}
<line
x1={padding.left}
y1={height - padding.bottom}
x2={viewBoxWidth - padding.right}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="1"
opacity="0.3"
/>
{/* Histogram bars */}
{bins.map((count, i) => {
const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0;
const x = padding.left + i * barWidth;
const y = height - padding.bottom - barHeight;
const binStart = min + i * data.binSize;
const binEnd = binStart + data.binSize;
return (
<g key={i}>
<rect
x={x}
y={y}
width={Math.max(barWidth - 1, 1)}
height={barHeight}
fill="currentColor"
opacity="0.6"
className="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
/>
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
</g>
);
})}
{/* Percentile lines */}
<line
x1={getXPosition(p50)}
y1={padding.top}
x2={getXPosition(p50)}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="2"
strokeDasharray="4 2"
opacity="0.7"
className="text-gray-600 dark:text-gray-400"
/>
<line
x1={getXPosition(p95)}
y1={padding.top}
x2={getXPosition(p95)}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="2"
strokeDasharray="4 2"
opacity="0.7"
className="text-orange-500 dark:text-orange-400"
/>
<line
x1={getXPosition(p99)}
y1={padding.top}
x2={getXPosition(p99)}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="2"
strokeDasharray="4 2"
opacity="0.7"
className="text-green-500 dark:text-green-400"
/>
{/* X-axis labels */}
<text
x={padding.left}
y={height - 5}
fontSize="10"
fill="currentColor"
opacity="0.6"
textAnchor="start"
>
{min.toFixed(1)}
</text>
<text
x={viewBoxWidth - padding.right}
y={height - 5}
fontSize="10"
fill="currentColor"
opacity="0.6"
textAnchor="end"
>
{max.toFixed(1)}
</text>
{/* X-axis label */}
<text
x={padding.left + chartWidth / 2}
y={height - 2}
fontSize="10"
fill="currentColor"
opacity="0.6"
textAnchor="middle"
>
Tokens/Second Distribution
</text>
</svg>
</div>
);
}
function StatsPanel() {
const { metrics } = useAPI();
const [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond] = useMemo(() => {
const [totalRequests, totalInputTokens, totalOutputTokens, tokenStats, histogramData] = useMemo(() => {
const totalRequests = metrics.length;
if (totalRequests === 0) {
return [0, 0, 0];
return [0, 0, 0, { p99: 0, p95: 0, p50: 0 }, null];
}
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
return [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond];
// Calculate token statistics using output_tokens and duration_ms
// Filter out metrics with invalid duration or output tokens
const validMetrics = metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
if (validMetrics.length === 0) {
return [totalRequests, totalInputTokens, totalOutputTokens, { p99: 0, p95: 0, p50: 0 }, null];
}
// Calculate tokens/second for each valid metric
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
// Sort for percentile calculation
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
// Calculate percentiles - showing speed thresholds where X% of requests are SLOWER (below)
// P99: 99% of requests are slower than this speed (99th percentile - fast requests)
// P95: 95% of requests are slower than this speed (95th percentile)
// P50: 50% of requests are slower than this speed (median)
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
// Create histogram data
const min = Math.min(...tokensPerSecond);
const max = Math.max(...tokensPerSecond);
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5))); // Adaptive bin count
const binSize = (max - min) / binCount;
const bins = Array(binCount).fill(0);
tokensPerSecond.forEach((value) => {
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
bins[binIndex]++;
});
const histogramData = {
bins,
min,
max,
binSize,
p99,
p95,
p50,
};
return [
totalRequests,
totalInputTokens,
totalOutputTokens,
{
p99: p99.toFixed(2),
p95: p95.toFixed(2),
p50: p50.toFixed(2),
},
histogramData,
];
}, [metrics]);
const nf = new Intl.NumberFormat();
return (
<div className="card">
<div className="rounded-lg overflow-hidden border border-gray-200 dark:border-white/10">
<table className="w-full">
<thead>
<tr className="border-b border-gray-200 dark:border-white/10 text-right">
<th>Requests</th>
<th className="border-l border-gray-200 dark:border-white/10">Processed</th>
<th className="border-l border-gray-200 dark:border-white/10">Generated</th>
<th className="border-l border-gray-200 dark:border-white/10">Tokens/Sec</th>
<div className="rounded-lg overflow-hidden border border-card-border-inner">
<table className="min-w-full divide-y divide-card-border-inner">
<thead className="bg-secondary">
<tr>
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">
Requests
</th>
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Processed
</th>
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Generated
</th>
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Token Stats (tokens/sec)
</th>
</tr>
</thead>
<tbody>
<tr className="text-right">
<td className="border-r border-gray-200 dark:border-white/10">{totalRequests}</td>
<td className="border-r border-gray-200 dark:border-white/10">
{new Intl.NumberFormat().format(totalInputTokens)}
<tbody className="bg-surface divide-y divide-card-border-inner">
<tr className="hover:bg-secondary">
<td className="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">{totalRequests}</td>
<td className="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div className="flex items-center gap-2">
<span className="text-sm font-medium">{nf.format(totalInputTokens)}</span>
<span className="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td className="border-r border-gray-200 dark:border-white/10">
{new Intl.NumberFormat().format(totalOutputTokens)}
<td className="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div className="flex items-center gap-2">
<span className="text-sm font-medium">{nf.format(totalOutputTokens)}</span>
<span className="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td className="px-4 py-4 border-l border-gray-200 dark:border-white/10">
<div className="space-y-3">
<div className="grid grid-cols-3 gap-2 items-center">
<div className="text-center">
<div className="text-xs text-gray-500 dark:text-gray-400">P50</div>
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{tokenStats.p50}
</div>
</div>
<div className="text-center">
<div className="text-xs text-gray-500 dark:text-gray-400">P95</div>
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{tokenStats.p95}
</div>
</div>
<div className="text-center">
<div className="text-xs text-gray-500 dark:text-gray-400">P99</div>
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{tokenStats.p99}
</div>
</div>
</div>
{histogramData && <TokenHistogram data={histogramData} />}
</div>
</td>
<td>{avgTokensPerSecond}</td>
</tr>
</tbody>
</table>