Compare commits

...

60 Commits

Author SHA1 Message Date
Benson Wong 22e098ac8b Add Peer Model Support (#438)
This PR allows a single llama-swap to be the central proxy for models served by other inference servers. The peer servers can be another llama-swap or any API that supports the /v1/* inference endpoint.

Updates: #433, #299
Closes: #296
2025-12-27 20:18:06 -08:00
Benson Wong 9864f9f517 .coderabbit.yaml: disable annoying features 2025-12-23 23:53:06 -08:00
Benson Wong 53b32f3601 proxy: add API key support (#436)
Add configuration support for api keys that are enforced by llama-swap. Keys are stripped before sending them to upstream servers. 

Updates: #433, #50 and #251
2025-12-23 23:39:33 -08:00
Benson Wong 565c44766d config,proxy: add new configuration logToStdout (#432)
The new logToStdout option controls what is logged to stdout. The
default has been changed to just the proxy logs, which contain swap and
http request logs.

There are four supported settings: none, proxy, upstream, both. The
"both" setting is the legacy setting where everything was spewed to
stdout.
2025-12-21 22:23:31 -08:00
Benson Wong e6a9e210ba proxy: fix path bug in /logs/stream/{model_id} (#431)
A {model_id} containing a forward slash trips up gin's path param
parsing. This updates /logs/stream to work like /upstream where the
model_id is built up in parts and searched for in the configuration.

Updates #421
2025-12-21 21:47:14 -08:00
Benson Wong d3f329f924 proxy: Improve logging performance and allow separate log streaming (#421)
Replace container/ring.Ring with a custom circularBuffer that uses a
single contiguous []byte slice. This fixes the original implementation
which created 10,240 ring elements instead of 10KB of storage.

GetHistory is now 139x faster (145μs → 1μs) and uses 117x less memory
(1.2MB → 10KB). Allocations reduced from 2 to 1 per write operation.

Create a LogMonitor per proxy.Process, replacing the usage
of a shared one. The buffer in LogMonitor is lazy allocated on the first
call to Write and freed when the Process is stopped. This reduces
unnecessary memory usage when a model is not active.

The /logs/stream/{model_id} endpoint was added to stream logs from a
specific process.
2025-12-18 21:49:25 -08:00
Benson Wong 98879b38c1 docker: add /app to $PATH (#424)
Make it so llama-server can be called directly instead of with the full
path at /app/llama-server.

Fixes #423
Ref: #233
2025-12-06 22:58:29 -08:00
Benson Wong 7b3b0f5eae move header images around [skip ci] 2025-12-02 19:40:42 -08:00
Benson Wong 021ccceef1 README: update hero image 2025-12-02 19:37:03 -08:00
Benson Wong f03871c50a Update README.md
- add supported anthropic API 
- add example for docker hot reload support
2025-12-02 19:03:01 -08:00
Ryan Steed dc00d17abe docs: add documentation for non-root container images and security considerations (#416)
* docs: add documentation for non-root container images and security considerations
* docs: move container security section to dedicated file and update README links
2025-12-02 08:52:26 -08:00
Benson Wong dea98733c3 proxy: extract metrics for v1/messages (#419) 2025-11-29 23:51:20 -08:00
Benson Wong bccce5fa19 go.mod,ui/package-lock.json: dependency and security updates (#418) 2025-11-29 22:27:22 -08:00
Benson Wong c968da1b73 proxy: add support for anthropic v1/messages api (#417)
* proxy: add support for anthropic v1/messages api
* proxy: restrict loading message to /v1/chat/completions
2025-11-29 22:09:07 -08:00
Ryan Steed a883d68d4f feat: Add support for custom llama.cpp base image and forked llama-swap repositories (#396)
* feat: Add support for custom llama.cpp base image and forked llama-swap repositories

- Introduce BASE_LLAMACPP_IMAGE env var to customize llama.cpp base image
- Introduce LS_REPO env var to customize llama-swap source
- Use GITHUB_REPOSITORY env var to automatically detect forked repos
- Update container tagging to use dynamic repo paths
- Pass build args for BASE_IMAGE and LS_REPO to Containerfile
- Enable flexible release downloads from forked repositories

* chore: quote entire curl options, appease coderabbitai
2025-11-29 20:59:15 -08:00
Ryan Steed b1dec8b735 docker: build both root and non-root container images (#412)
Change the user back to root for containers. Additionally, built a "non-root" labeled container for users who wish to have the additional security of running llama-swap as a lower privileged user.
2025-11-25 10:44:13 -08:00
Nikesh Parajuli 06523d8c1e feat: add platform-specific process attributes support (#411)
Fixes issues on Windows showing new windows for every process llama-swap spawns.
2025-11-24 21:39:56 -08:00
Ryan Steed 86e9b93c37 proxy,ui: add version endpoint and display version info in UI (#395)
- Add /api/version endpoint to ProxyManager that returns build date, commit hash, and version
- Implement SetVersion method to configure version info in ProxyManager
- Add version info fetching to APIProvider and display in ConnectionStatus component
- Include version info in UI context and update dependencies
- Add tests for version endpoint functionality
2025-11-17 10:43:47 -08:00
Ryan Steed 3acace810f proxy: add configurable logging timestamp format (#401)
introduces a new configuration option logTimeFormat that allows customizing the timestamp in log messages using golang's built in time format constants. The default remains no timestamp.
2025-11-16 10:21:59 -08:00
Ryan Steed 554d29e87d feat: enhance model listing to include aliases (#400)
introduce includeAliasesInList as a new configuration setting (default false) that includes aliases in v1/models

Fixes #399
2025-11-15 14:35:26 -08:00
Benson Wong 3567b7df08 Update image in README.md for web UI section 2025-11-08 15:29:37 -08:00
Benson Wong 38738525c9 config.example.yaml: add modeline for schema validation 2025-11-08 15:08:55 -08:00
Benson Wong c0fc858193 Add configuration file JSON schema (#393)
* add json schema for configuration
* add GH action to validate schema
2025-11-08 15:04:14 -08:00
Benson Wong b429349e8a add /ui/ to wol-proxy polling (#388) 2025-11-08 14:16:12 -08:00
Ryan Steed eab2efd7b5 feat: improve llama.cpp base image tag for cpu (#391)
Refactor the container build script to resolve llama.cpp base image for CPU, also tag these builds accordingly.

- For CPU containers, now fetch the latest 'server' tagged llama.cpp image instead of using a generic 'server' tag
- Cleans up the docker build command to use dynamic BASE_TAG variable
- Maintains existing push functionality for built images
2025-11-08 09:56:49 -08:00
Benson Wong 6aedbe121a cmd/wol-proxy: show a loading page for / (#381)
When requesting / wol-proxy will show a loading page that polls /status
every second. When the upstream server is ready the loading page will
refresh causing the actual root page to be displayed
2025-11-03 19:37:06 -08:00
Ryan Steed b24467ab89 fix: update containerfile user/group management commands (#379)
- Replace `addgroup` with `groupadd` for system group creation
- Replace `adduser` with `useradd` for system user creation
- Maintain same functionality while using more standard POSIX commands
2025-11-03 17:17:40 -05:00
Benson Wong 12b69fb718 proxy: recover from panic in Process.statusUpdate (#378)
Process.statusUpdate() panics when it can not write data, usually from a
client disconnect. Since it runs in a goroutine and did not have a
recover() the result was a crash.

ref: https://github.com/mostlygeek/llama-swap/discussions/326#discussioncomment-14856197
2025-11-03 05:30:09 -08:00
Ryan Steed f91a8b2462 refactor: update Containerfile to support non-root user execution and improve security (#368)
Set default container user/group to lower privilege app user 

* refactor: update Containerfile to support non-root user execution and improve security

- Updated LS_VER argument from 89 to 170 to use the latest version
- Added UID/GID arguments with default values of 0 (root) for backward compatibility
- Added USER_HOME environment variable set to /root
- Implemented conditional user/group creation logic that only runs when UID/GID are not 0
- Created necessary directory structure with proper ownership using mkdir and chown commands
- Switched to non-root user execution for improved security posture
- Updated COPY instruction to use --chown flag for proper file ownership

* chore: update containerfile to use non-root user with proper UID/GID

- Changed default UID and GID from 0 (root) to 10001 for security best practices
- Updated USER_HOME from /root to /app to avoid running as root user
2025-10-31 17:01:04 -07:00
Benson Wong a89b803d4a Stream loading state when swapping models (#371)
Swapping models can take a long time and leave a lot of silence while the model is loading. Rather than silently load the model in the background, this PR allows llama-swap to send status updates in the reasoning_content of a streaming chat response.

Fixes: #366
2025-10-29 00:09:39 -07:00
Benson Wong f852689104 proxy: add panic recovery to Process.ProxyRequest (#363)
Switching to use httputil.ReverseProxy in #342 introduced a possible
panic if a client disconnects while streaming the body. Since llama-swap
does not use http.Server the recover() is not automatically there.

- introduce a recover() in Process.ProxyRequest to recover and log the
  event
- add TestProcess_ReverseProxyPanicIsHandled to reproduce and test the
  fix

fixes: #362
2025-10-25 20:40:05 -07:00
Benson Wong e250e71e59 Include metrics from upstream chat requests (#361)
* proxy: refactor metrics recording

- remove metrics_middleware.go as this wrapper is no longer needed. This
  also eliminiates double body parsing for the modelID
- move metrics parsing to be part of MetricsMonitor
- refactor how metrics are recording in ProxyManager
- add MetricsMonitor tests
- improve mem efficiency of processStreamingResponse
- add benchmarks for MetricsMonitor.addMetrics
- proxy: refactor MetricsMonitor to be more safe handling errors
2025-10-25 17:38:18 -07:00
Benson Wong d18dc26d01 cmd/wol-proxy: tweak logs to show what is causing wake ups (#356)
fix the extra wake ups being caused by wol-proxy

* cmd/wol-proxy: tweak logs to show what is causing wake ups
* cmd/wol-proxy: add skip wakeup
* cmd/wol-proxy: replace ticker with SSE connection
* cmd/wol-proxy: increase scanner buffer size
* cmd/wol-proxy: improve failure tracking
2025-10-25 11:04:31 -07:00
Benson Wong 8357714421 ui: fix avg token/sec calculation on models page (#357)
* ui: use percentiles for token stats
* ui: add histogram of metrics
* update vite to remove security warnings

fixes #355
2025-10-23 22:22:24 -07:00
Benson Wong c07179d6e2 cmd/wol-proxy: add wol-proxy (#352)
add a wake-on-lan proxy for llama-swap. When the target llama-swap server is unreachable it will send hold a request, send a WoL packet and proxy the request when llama-swap is available.
2025-10-20 20:55:02 -07:00
Benson Wong 7ff50631e0 Update README for setup instructions clarity [skip ci] 2025-10-19 14:55:23 -07:00
Benson Wong 9fc0431531 Clean up and Documentation (#347) [skip ci]
* cmd,misc: move misc binaries to cmd/
* docs: add docs and move examples/ there
* misc: remove unused misc/assets dir
* docs: add configuration.md
* update README with better structure

Updates: #334
2025-10-19 14:53:13 -07:00
David Wen Riccardi-Zhu 6516532568 Add optional TLS support (#340)
* Add optional TLS support

Introduce HTTPS support with net/http Server.ListenAndServeTLS.

This should enable the option of serving via HTTPS without a reverse
proxy.

Add two flags:
- tls-cert-file (path to the TLS certificate file)
- tls-key-file (path to the TLS private key file)

Both flags must be supplied together; otherwise exit with error.

If both flags are present, call srv.ListenAndServeTLS.
If not, fall back to the existing srv.ListenAndServe (HTTP); no changes
to existing non‑TLS behavior.
2025-10-15 19:29:02 -07:00
David Wen Riccardi-Zhu d58a8b85bf Refactor to use httputil.ReverseProxy (#342)
* Refactor to use httputil.ReverseProxy

Refactor manual HTTP proxying logic in Process.ProxyRequest to use the standard
library's httputil.ReverseProxy.

* Refactor TestProcess_ForceStopWithKill test

Update to handle behavior with httputil.ReverseProxy.

* Fix gin interface conversion panic
2025-10-13 16:47:04 -07:00
Benson Wong caf9e98b1e Fix race conditions in proxy.Process (#349)
- Fix data races found in proxy.Process by go's race detector. 
- Add data race detection to the CI tests. 

Fixes #348
2025-10-13 16:42:49 -07:00
Benson Wong 539278343b ui: tweak vertical space for mobile (#343) 2025-10-10 10:05:36 -07:00
Benson Wong 00b738cd0f Add Macro-In-Macro Support (#337)
Add full macro-in-macro support so any user defined macro can contain another one as long as it was previously declared in the configuration file.

Fixes #336
Supercedes #335
2025-10-06 22:57:15 -07:00
Benson Wong 70930e4e91 proxy: add support for user defined metadata in model configs (#333)
Changes: 

- add Metadata key to ModelConfig
- include metadata in /v1/models under meta.llamaswap key
- add recursive macro substitution into Metadata
- change macros at global and model level to be any scalar type

Note: 

This is the first mostly AI generated change to llama-swap. See #333 for notes about the workflow and approach to AI going forward.
2025-10-04 19:56:41 -07:00
Benson Wong 1f6179110c proxy/config: add model level macros (#330)
* proxy/config: add model level macros

Add macros to model configuration. Model macros override macros that are
defined at the global configuration level. They follow the same naming
and value rules as the global macros.

* proxy/config: fix bug with macro reserved name checking

The PORT reserved name was not properly checked

* proxy/config: add tests around model.filters.stripParams

- add check that model.filters.stripParams has no invalid macros
- renamed strip_params to stripParams for camel case consistency
- add legacy code compatibility so  model.filters.strip_params continues to work

* proxy/config: add duplicate removal to model.filters.stripParams

* clean up some doc nits
2025-09-28 23:32:52 -07:00
Benson Wong 216c40b951 proxy/config: create config package and migrate configuration (#329)
* proxy/config: create config package and migrate configuration

The configuration is become more complex as llama-swap adds more
advanced features. This commit moves config to its own package so it can
be developed independently of the proxy package.

Additionally, enforcing a public API for a configuration will allow
downstream usage to be more decoupled.
2025-09-28 16:50:06 -07:00
Benson Wong 9e3d491c85 proxyToUpstream: add redirect with trailing slash to upstream endpoint (#322)
This adds a redirect to the upstream endpoint so it always ends with a trailing /. 

Fixes #321
2025-09-25 16:43:00 -07:00
Benson Wong 1a84926505 proxy: add unload of single model (#318)
This adds a new API endpoint, /api/models/unload/*model, that unloads a single model. In the UI when a model is in a ReadyState it will have a new button to unload it. 

Fixes #312
2025-09-24 20:53:48 -07:00
Oleg Shulyakov fc3bb716df UI styling / code improvements (#307)
Clean up and improve UI styling

* fix: UI - dependency cleanup
* chore: UI - start script
* refactor: UI - Extract Header
* fix: UI - Header styling
* fix: UI - LogViewer styling
* fix: UI - Models styling
* fix: UI - Activity styling
* fix: UI - ConnectionStatus colors
* review: UI - table border colors
2025-09-19 10:47:17 -07:00
Benson Wong c36986fef6 upstream handler support for model names with forward slash (#298)
The upstream handler would break on model IDs that contained a forward
slash. Model IDs like "aaa/bbb" called at upstream/aaa/bbb would result
in an error. This commit adds support for model IDs with a forward slash
by iteratively searching the path for a match.

Fixes: #229
2025-09-13 13:37:03 -07:00
Artur Podsiadły 558801db1a Fix nginx proxy buffering for streaming endpoints (#295)
* Fix nginx proxy buffering for streaming endpoints

- Add X-Accel-Buffering: no header to SSE endpoints (/api/events, /logs/stream)
- Add X-Accel-Buffering: no header to proxied text/event-stream responses
- Add nginx reverse proxy configuration section to README
- Add tests for X-Accel-Buffering header on streaming endpoints

Fixes #236

* Fix goroutine cleanup in streaming endpoints test

Add context cancellation to TestProxyManager_StreamingEndpointsReturnNoBufferingHeader
to ensure the goroutine is properly cleaned up when the test completes.
2025-09-09 16:07:46 -07:00
Benson Wong b21dee27c1 Fix #288 Vite hot module reloading creating multiple SSE connections (#290)
- move SSE (EventSource) connection to module level
- manage EventSource as a singleton, closing open connection before
  reopening a new one
2025-09-07 21:48:58 -07:00
Benson Wong f58c8c8ec5 Support llama.cpp's cache_n in timings info (#287)
Capture prompt cache metrics and surface them on Activities page in UI
2025-09-06 13:58:02 -07:00
Benson Wong 954e2dee73 Remove cmdStart from README [skip ci]
cmdStart was in the README but it doesn't exist. Fixed the typo. Oops.
2025-09-04 11:57:28 -07:00
Benson Wong a533aec736 small tweak to example config 2025-09-01 21:26:58 -07:00
Brett Profitt 97b17fc47d Add ${MODEL_ID} macro (#226)
The automatic ${MODEL_ID} macro includes the name of the model and can be used in Cmd and CmdStop.
2025-09-01 21:21:37 -07:00
Benson Wong 2457840698 Update README.md [skip ci] 2025-08-28 23:44:37 -07:00
Benson Wong 7f55494151 Update README.md [skip ci] 2025-08-28 22:47:28 -07:00
Benson Wong 831a90d3b0 Add different timeout scenarios to Process.checkHealthEndpoint #276 (#278)
- add a TCP connection timeout of 500ms
- increase HTTP client timeout to 5000ms

In this new behaviour the upstream has 500ms to accept a tcp connection
and 5000ms to respond to the HTTP request.
2025-08-28 22:03:14 -07:00
Yandrik 977f1856bb add /completion endpoint (#275)
* feat: add /completion endpoint
* chore: reformat using gofmt
2025-08-28 21:41:02 -07:00
Benson Wong 52b329f7bc Fix #277 race condition in ProcessGroup.ProxyRequest when swap=true 2025-08-28 21:38:40 -07:00
77 changed files with 8724 additions and 1577 deletions
+7
View File
@@ -8,8 +8,15 @@ reviews:
poem: false
review_status: true
collapse_walkthrough: false
sequence_diagrams: false
finishing_touches:
docstrings:
enabled: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true
issue_enrichment:
planning:
enabled: false
+41
View File
@@ -0,0 +1,41 @@
name: Validate JSON Schema
on:
pull_request:
paths:
- "config-schema.json"
push:
branches:
- main
paths:
- "config-schema.json"
workflow_dispatch:
jobs:
validate-schema:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Validate JSON Schema
run: |
# Check if the file is valid JSON
if ! jq empty config-schema.json 2>/dev/null; then
echo "Error: config-schema.json is not valid JSON"
exit 1
fi
# Validate that it's a valid JSON Schema
# Check for required $schema field
if ! jq -e '."$schema"' config-schema.json > /dev/null; then
echo "Warning: config-schema.json should have a \$schema field"
fi
# Check that it has either properties or definitions
if ! jq -e '.properties or .definitions or ."$defs"' config-schema.json > /dev/null; then
echo "Warning: JSON Schema should contain properties, definitions, or \$defs"
fi
echo "✓ config-schema.json is valid"
+43
View File
@@ -0,0 +1,43 @@
# Project: llama-swap
## Project Description:
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
## Tech stack
- golang
- typescript, vite and react for UI (ui/)
## Testing
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
## Workflow Tasks
### Plan Improvements
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
When the user asks to improve a plan follow these guidelines for expanding and improving it.
- Identify any inconsistencies.
- Expand plans out to be detailed specification of requirements and changes to be made.
- Plans should have at least these sections:
- Title - very short, describes changes
- Overview: A more detailed summary of goal and outcomes desired
- Design Requirements: Detailed descriptions of what needs to be done
- Testing Plan: Tests to be implemented
- Checklist: A detailed list of changes to be made
Look for "plan expansion" as explicit instructions to improve a plan.
### Implementation of plans
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
## General Rules
- when summarizing changes only include details that require further action (action items)
- when there are no action items, just say "Done."
+19 -7
View File
@@ -23,11 +23,17 @@ proxy/ui_dist/placeholder.txt:
mkdir -p proxy/ui_dist
touch $@
test: proxy/ui_dist/placeholder.txt
go test -short -v -count=1 ./proxy
# use cached test results while developing
test-dev: proxy/ui_dist/placeholder.txt
go test -short ./proxy/...
staticcheck ./proxy/... || true
test: proxy/ui_dist/placeholder.txt
go test -short -count=1 ./proxy/...
# for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt
go test -v -count=1 ./proxy
go test -race -count=1 ./proxy/...
ui/node_modules:
cd ui && npm install
@@ -55,12 +61,12 @@ windows: ui
# for testing proxy.Process
simple-responder:
@echo "Building simple responder"
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 cmd/simple-responder/simple-responder.go
simple-responder-windows:
@echo "Building simple responder for windows"
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe cmd/simple-responder/simple-responder.go
# Ensure build directory exists
$(BUILD_DIR):
@@ -80,5 +86,11 @@ release:
echo "tagging new version: $$new_tag"; \
git tag "$$new_tag";
GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
wol-proxy: $(BUILD_DIR)
@echo "Building wol-proxy"
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
# Phony targets
.PHONY: all clean ui mac linux windows simple-responder
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
+149 -117
View File
@@ -1,82 +1,51 @@
![llama-swap header image](header2.png)
![llama-swap header image](docs/assets/hero3.webp)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
## Features:
- ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
- ✅ On-demand model switching
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc.)
- future proof, upgrade your inference servers at any time.
- ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/embeddings`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
-llama-server (llama.cpp) supported endpoints:
-Anthropic API supported endpoints:
- `v1/messages`
- ✅ llama-server (llama.cpp) supported endpoints
- `v1/rerank`, `v1/reranking`, `/rerank`
- `/infill` - for code infilling
- ✅ llama-swap custom API endpoints
- `/completion` - for completion endpoint
- ✅ llama-swap API
- `/ui` - web UI
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/log` - remote log monitoring
- `/health` - just returns "OK"
-Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
-Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Reliable Docker and Podman support with `cmdStart` and `cmdStop`
- ✅ Full control over server settings per model
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
-API Key support - define keys to restrict access to API endpoints
-Customizable
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- Automatic unloading of models after timeout by setting a `ttl`
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
## How does llama-swap work?
### Web UI
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
llama-swap includes a real time web interface for monitoring logs and controlling models:
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## config.yaml
llama-swap is managed entirely through a yaml configuration file.
It can be very minimal to start:
```yaml
models:
"qwen2.5":
cmd: |
/path/to/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port ${PORT}
```
However, there are many more capabilities that llama-swap supports:
- `groups` to run multiple models at once
- `ttl` to automatically unload models
- `macros` for reusable snippets
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
- `env` to pass custom environment variables to inference servers
- `cmdStop` for to gracefully stop Docker/Podman containers
- `useModelName` to override model names sent to upstream servers
- `healthCheckTimeout` to control model startup wait times
- `${PORT}` automatic port variables for dynamic port assignment
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
## Web UI
llama-swap includes a real time web interface for monitoring logs and models:
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/adef4a8e-de0b-49db-885a-8f6dedae6799" />
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
The Activity Page shows recent requests:
@@ -88,113 +57,173 @@ llama-swap can be installed in multiple ways
1. Docker
2. Homebrew (OSX and Linux)
3. From release binaries
4. From source
3. WinGet
4. From release binaries
5. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker images with llama-swap and llama-server are built nightly.
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
```shell
# use CPU inference comes with the example config above
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
# qwen2.5 0.5B
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
# run with a custom configuration and models directory
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
# SmolLM2 135M
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
# configuration hot reload supported with a
# directory volume mount
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
-v /path/to/config:/config \
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
```
<details>
<summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
They include:
- `ghcr.io/mostlygeek/llama-swap:cpu`
- `ghcr.io/mostlygeek/llama-swap:cuda`
- `ghcr.io/mostlygeek/llama-swap:intel`
- `ghcr.io/mostlygeek/llama-swap:vulkan`
- ROCm disabled until fixed in llama.cpp container
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
<summary>
more examples
</summary>
```shell
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
# pull latest images per platform
docker pull ghcr.io/mostlygeek/llama-swap:cpu
docker pull ghcr.io/mostlygeek/llama-swap:cuda
docker pull ghcr.io/mostlygeek/llama-swap:vulkan
docker pull ghcr.io/mostlygeek/llama-swap:intel
docker pull ghcr.io/mostlygeek/llama-swap:musa
# tagged llama-swap, platform and llama-server version images
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
# non-root cuda
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
```
</details>
### Homebrew Install (macOS/Linux)
The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
```shell
# Set up tap and install formula
brew tap mostlygeek/llama-swap
brew install llama-swap
# Run llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
### WinGet Install (Windows)
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
> [!NOTE]
> WinGet is maintained by community contributor [Dvd-Znf](https://github.com/Dvd-Znf) ([#327](https://github.com/mostlygeek/llama-swap/issues/327)). It is not an official part of llama-swap.
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
```shell
# install
C:\> winget install llama-swap
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
Available flags:
- `--config`: Path to the configuration file (default: `config.yaml`).
- `--listen`: Address and port to listen on (default: `:8080`).
- `--version`: Show version information and exit.
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
# upgrade
C:\> winget upgrade llama-swap
```
### Pre-built Binaries
Binaries are available on the [release](https://github.com/mostlygeek/llama-swap/releases) page for Linux, Mac, Windows and FreeBSD.
### Building from source
1. Build requires golang and nodejs for the user interface.
1. Building requires Go and Node.js (for UI).
1. `git clone https://github.com/mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
1. look in the `build/` subdirectory for the llama-swap binary
## Monitoring Logs
## Configuration
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
```yaml
# minimum viable config.yaml
CLI access is also supported:
models:
model1:
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
```
```shell
That's all you need to get started:
1. `models` - holds all model configurations
2. `model1` - the ID used in API calls
3. `cmd` - the command to run to start the server.
4. `${PORT}` - an automatically assigned port number
Almost all configuration settings are optional and can be added one step at a time:
- Advanced features
- `groups` to run multiple models at once
- `hooks` to run things on startup
- `macros` reusable snippets
- Model customization
- `ttl` to automatically unload models
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
- `env` to pass custom environment variables to inference servers
- `cmdStop` gracefully stop Docker/Podman containers
- `useModelName` to override model names sent to upstream servers
- `${PORT}` automatic port variables for dynamic port assignment
- `filters` rewrite parts of requests before sending to the upstream server
See the [configuration documentation](docs/configuration.md) for all options.
## How does llama-swap work?
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## Reverse Proxy Configuration (nginx)
If you deploy llama-swap behind nginx, disable response buffering for streaming endpoints. By default, nginx buffers responses which breaks ServerSent Events (SSE) and streaming chat completion. ([#236](https://github.com/mostlygeek/llama-swap/issues/236))
Recommended nginx configuration snippets:
```nginx
# SSE for UI events/logs
location /api/events {
proxy_pass http://your-llama-swap-backend;
proxy_buffering off;
proxy_cache off;
}
# Streaming chat completions (stream=true)
location /v1/chat/completions {
proxy_pass http://your-llama-swap-backend;
proxy_buffering off;
proxy_cache off;
}
```
As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. However, explicitly disabling `proxy_buffering` at your reverse proxy is still recommended for reliable streaming behavior.
## Monitoring Logs on the CLI
```sh
# sends up to the last 10KB of logs
curl http://host/logs'
$ curl http://host/logs
# streams combined logs
curl -Ns 'http://host/logs/stream'
curl -Ns http://host/logs/stream
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# stream llama-swap's proxy status logs
curl -Ns http://host/logs/stream/proxy
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream logs from upstream processes that llama-swap loads
curl -Ns http://host/logs/stream/upstream
# stream logs only from a specific model
curl -Ns http://host/logs/stream/{model_id}
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
# skips history and just streams new log entries
# appending ?no-history will disable sending buffered history first
curl -Ns 'http://host/logs/stream?no-history'
```
@@ -202,8 +231,11 @@ curl -Ns 'http://host/logs/stream?no-history'
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals for proper shutdown.
## Star History
> [!NOTE]
> ⭐️ Star this project to help others discover it!
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
@@ -0,0 +1,85 @@
# Replace ring.Ring with Efficient Circular Byte Buffer
## Overview
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
## Current Issues
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
4. Linked list structure has poor cache locality and pointer overhead
## Design Requirements
### New CircularBuffer Type
Create a simple circular byte buffer with:
- Single pre-allocated `[]byte` of fixed capacity (10KB)
- `head` and `size` integers to track write position and data length
- No per-write allocations
### API Requirements
The new buffer must support:
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
### Implementation Details
```go
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
}
```
**Write logic:**
- If `len(p) >= capacity`: just keep the last `capacity` bytes
- Otherwise: write bytes at `head`, wrapping around if needed
- Update `head` and `size` accordingly
- Data is copied into the internal buffer (not stored by reference)
**GetHistory logic:**
- Calculate start position: `(head - size + cap) % cap`
- If not wrapped: single slice copy
- If wrapped: two copies (end of buffer + beginning)
- Returns a **new slice** (copy), not a view into internal buffer
### Immutability Guarantees (must preserve)
Per existing tests:
1. Modifying input `[]byte` after `Write()` must not affect stored data
2. `GetHistory()` returns independent copy - modifications don't affect buffer
## Files to Modify
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
## Testing Plan
Existing tests in `logMonitor_test.go` should continue to pass:
- `TestLogMonitor` - Basic write/read and subscriber notification
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
- `TestWrite_LogTimeFormat` - Timestamp formatting
Add new tests:
- Test buffer wrap-around behavior
- Test large writes that exceed buffer capacity
- Test exact capacity boundary conditions
## Checklist
- [ ] Create `circularBuffer` struct in `logMonitor.go`
- [ ] Implement `Write()` method for circular buffer
- [ ] Implement `GetHistory()` method for circular buffer
- [ ] Update `LogMonitor` struct to use new buffer
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
- [ ] Update `LogMonitor.Write()` to use new buffer
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
- [ ] Remove `"container/ring"` import
- [ ] Run `make test-dev` to verify existing tests pass
- [ ] Add wrap-around test case
- [ ] Run `make test-all` for final validation
+292
View File
@@ -0,0 +1,292 @@
# Add Model Metadata Support with Typed Macros
## Overview
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
## Design Requirements
### 1. Enhanced Macro System
**Current State:**
- Macros are defined as `map[string]string` at both global and model levels
- Only string substitution is supported
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
**Required Changes:**
- Change `MacroList` type from `map[string]string` to `map[string]any`
- Support scalar types: `string`, `int`, `float64`, `bool`
- Implement type-preserving macro substitution:
- Direct macro usage (`key: ${macro}`) preserves the macro's type
- Interpolated usage (`key: "text ${macro}"`) converts to string
- Add validation to ensure macro values are scalar types only
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
**Implementation Details:**
- Create a generic helper function to perform macro substitution that:
- Takes a value of type `any`
- Recursively processes maps, slices, and scalar values
- Replaces `${macro_name}` patterns with macro values
- Preserves types for direct substitution
- Converts to strings for interpolated substitution
- Update `validateMacro()` function to accept `any` type and validate scalar types
- Maintain backward compatibility with existing string-only macros
### 2. Metadata Field in ModelConfig
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
**Required Changes:**
- Add `Metadata map[string]any` field to `ModelConfig` struct
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
- Apply macro substitution to metadata values during config loading
**Schema Requirements:**
- Metadata is optional (default: empty/nil map)
- Supports nested structures (objects within objects, arrays, etc.)
- All string values within metadata undergo macro substitution
- Type preservation rules apply as described above
### 3. Macro Substitution in Metadata
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
**Process Flow:**
1. After loading YAML configuration
2. After model-level and global macro merging
3. Apply macro substitution to `ModelConfig.Metadata` field
4. Use the same merged macros available to `cmd`, `proxy`, etc.
5. Process recursively through all nested structures
**Substitution Rules:**
- `port: ${PORT}` → keeps integer type from PORT macro
- `temperature: ${temp}` → keeps float type from temp macro
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
- Arrays and nested objects are processed recursively
- Unknown macros should cause configuration load error (consistent with existing behavior)
### 4. API Response Updates
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
**Current Behavior:**
- Returns model records with: `id`, `object`, `created`, `owned_by`
- Optionally includes: `name`, `description`
**Required Changes:**
- Add metadata to each model record under the key `llamaswap_meta`
- Only include `llamaswap_meta` if metadata is non-empty
- Preserve all types when marshaling to JSON
- Maintain existing sorting by model ID
**Example Response:**
```json
{
"object": "list",
"data": [
{
"id": "llama",
"object": "model",
"created": 1234567890,
"owned_by": "llama-swap",
"name": "llama 3.1 8B",
"description": "A small but capable model",
"llamaswap_meta": {
"port": 10001,
"temperature": 0.7,
"note": "The llama is running on port 10001 temp=0.7, context=16384",
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
"an_obj": {
"a": "1",
"b": 2,
"c": [0.7, false, "model: llama"]
}
}
}
]
}
```
### 5. Validation and Error Handling
**Macro Validation:**
- Extend `validateMacro()` to accept values of type `any`
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
- Reject complex types (maps, slices, structs) as macro values
- Maintain existing validation for macro names and lengths
**Configuration Loading:**
- Fail fast if unknown macros are found in metadata
- Provide clear error messages indicating which model and field contains errors
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
## Testing Plan
### Test 1: Model-Level Macros with Different Types
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
**Test Cases:**
- Define model with macros of each scalar type
- Verify metadata correctly substitutes and preserves types
- Test direct substitution (`port: ${PORT}`)
- Test string interpolation (`note: "Port is ${PORT}"`)
- Verify nested objects and arrays work correctly
### Test 2: Global and Model Macro Precedence
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
**Test Cases:**
- Define same macro at global and model level with different types
- Verify model-level macro takes precedence
- Test metadata uses correct macro value
- Verify type is preserved from the winning macro
### Test 3: Macro Validation
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
**Test Cases:**
- Test that complex types (maps, arrays) are rejected as macro values
- Verify error message includes: macro name and type that was rejected
- Test that scalar types (string, int, float, bool) are accepted
- Each type should load without error
- Test macro name validation still works with `any` types
- Invalid characters, reserved names, length limits should still be enforced
### Test 4: Metadata in API Response
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
**Existing Test:** `TestProxyManager_ListModelsHandler`
**Test Cases:**
- Model with metadata → verify `llamaswap_meta` key appears
- Model without metadata → verify `llamaswap_meta` key is absent
- Verify all types are correctly marshaled to JSON
- Verify nested structures are preserved
- Verify macro substitution has occurred before serialization
### Test 5: Unknown Macros in Metadata
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
**Test Cases:**
- Use undefined macro in metadata
- Verify configuration loading fails with clear error
- Error should indicate model name and that macro is undefined
### Test 6: Recursive Substitution
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
**Test Cases:**
- Metadata with deeply nested structures
- Arrays containing objects with macros
- Objects containing arrays with macros
- Mixed string interpolation and direct substitution at various nesting levels
## Checklist
### Configuration Schema Changes
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
- [x] Update `validateMacro()` function signature to accept `any` type for values
- [x] Add validation logic to ensure macro values are scalar types only
### Macro Substitution Logic
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
- [x] Implement type-preserving direct substitution logic
- [x] Implement string interpolation with type conversion
- [x] Handle maps: recursively process all values
- [x] Handle slices: recursively process all elements
- [x] Handle scalar types: perform string-based macro substitution if value is string
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
- [x] Update existing macro substitution calls to use merged macros with correct types
### API Response Changes
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
- [x] Add `llamaswap_meta` field to model records when metadata exists
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
- [x] Verify JSON marshaling preserves all types correctly
### Testing - Config Package
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
### Testing - Model Config Package
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
- [x] Test metadata with various scalar types
- [x] Test metadata with nested objects and arrays
### Testing - Proxy Manager
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
- [x] Add test case for model with metadata
- [x] Add test case for model without metadata
- [x] Verify `llamaswap_meta` key presence/absence
- [x] Verify type preservation in JSON output
- [x] Verify macro substitution has occurred
### Documentation
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
- [x] No additional documentation needed per project instructions
## Known Issues and Considerations
### Inconsistencies
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
### Design Decisions
1. **Why `llamaswap_meta` instead of merging into record?**
- Avoids potential collisions with OpenAI API standard fields
- Makes it clear this is llama-swap specific metadata
- Easier for clients to distinguish standard vs. custom fields
2. **Why support nested structures?**
- Provides maximum flexibility for users
- Aligns with the schemaless design principle
- Example config already demonstrates this capability
3. **Why validate macro types?**
- Prevents confusing behavior (e.g., substituting a map)
- Makes configuration errors explicit at load time
- Simpler implementation and testing
+397
View File
@@ -0,0 +1,397 @@
# Improve macro-in-macro support
**Status: COMPLETED ✅**
## Title
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
## Overview
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
**Outcomes:**
- Macros can reference other macros defined earlier in the config
- Macro substitution is deterministic and order-dependent
- Single-pass substitution prevents circular dependencies
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
- All existing tests pass
- New tests validate substitution order and self-reference detection
## Design Requirements
### 1. YAML Parsing Strategy
- **Continue using:** `gopkg.in/yaml.v3` (current library)
- **Use:** `yaml.Node` for ordered parsing of macros
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
### 2. Data Structure Changes
#### Current Implementation (config.go:19)
```go
type MacroList map[string]any
```
#### New Implementation
```go
type MacroList []MacroEntry
type MacroEntry struct {
Name string
Value any
}
```
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
### 3. Macro Substitution Order Rules
The substitution must follow this hierarchy (from most specific to least):
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
2. **Model-level macros** (middle): Defined in specific model config, overrides global
3. **Global macros** (first): Defined at config root level
Within each level, macros are substituted in **reverse definition order** (LIFO):
- The last macro defined is substituted first
- This allows later macros to reference earlier ones
- Single-pass substitution prevents circular dependencies
### 4. Macro Reference Rules
**Allowed:**
- Macro can reference any macro defined **before** it (earlier in the file)
- Model macros can reference global macros
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
**Prohibited:**
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
- Macro cannot reference macros defined **after** it
- No circular references (prevented by single-pass, ordered substitution)
### 5. Validation Requirements
Add validation to detect:
- **Self-references:** Macro value contains reference to its own name
- **Unknown macros:** After substitution, any remaining `${...}` references
Error messages should be clear:
```
macro 'foo' contains self-reference
unknown macro '${bar}' in model.cmd
```
### 6. Implementation Changes
#### Files to Modify
1. **[proxy/config/config.go](proxy/config/config.go)**
- Line 19: Change `MacroList` type definition
- Line 69: Update `Macros MacroList` field
- Line 153-157: Update macro validation loop to work with ordered structure
- Line 175-188: Update model-level macro validation
- Line 181-188: **NEW** Implement proper macro merging respecting order
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
- Line 389-415: Update `validateMacro` to detect self-references
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
- Line 33: Update `Macros MacroList` field type
3. **All test files**
- Update test fixtures to use ordered macro definitions
- Ensure tests specify macro order explicitly
#### Core Algorithm
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
```go
// Merge global config and model macros. Model macros take precedence
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
// Add global macros first
for _, entry := range config.Macros {
mergedMacros = append(mergedMacros, entry)
}
// Add model macros (can override global)
for _, entry := range modelConfig.Macros {
// Remove any existing global macro with same name
found := false
for i, existing := range mergedMacros {
if existing.Name == entry.Name {
mergedMacros[i] = entry // Override
found = true
break
}
}
if !found {
mergedMacros = append(mergedMacros, entry)
}
}
// Add reserved MODEL_ID macro at the end
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
// Check if PORT macro is needed
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
// enforce ${PORT} used in both cmd and proxy
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// Add PORT macro to the end (highest priority)
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
nextPort++
}
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
// This allows later macros to reference earlier ones
for i := len(mergedMacros) - 1; i >= 0; i-- {
entry := mergedMacros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
// Substitute in command fields
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
// Substitute in metadata (recursive)
if len(modelConfig.Metadata) > 0 {
var err error
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
}
}
```
Add this new helper function to replace `substituteMetadataMacros`:
```go
// substituteMacroInValue recursively substitutes a single macro in a value structure
// This is called once per macro, allowing LIFO substitution order
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
macroSlug := fmt.Sprintf("${%s}", macroName)
macroStr := fmt.Sprintf("%v", macroValue)
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if v == macroSlug {
return macroValue, nil
}
// Handle string interpolation
if strings.Contains(v, macroSlug) {
return strings.ReplaceAll(v, macroSlug, macroStr), nil
}
return v, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newMap[key] = newVal
}
return newMap, nil
case []any:
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newSlice[i] = newVal
}
return newSlice, nil
default:
// Return scalar types as-is
return value, nil
}
}
```
### 7. Self-Reference Detection
Add to `validateMacro` function:
```go
func validateMacro(name string, value any) error {
// ... existing validation ...
// Check for self-reference
if str, ok := value.(string); ok {
macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(str, macroSlug) {
return fmt.Errorf("macro '%s' contains self-reference", name)
}
}
return nil
}
```
## Testing Plan
### 1. Migration Tests
- **Test:** All existing macro tests still pass after YAML library migration
- **Files:** All `*_test.go` files with macro tests
### 2. Macro Order Tests
#### Test: Macro-in-macro substitution order
```yaml
macros:
"A": "value-A"
"B": "prefix-${A}-suffix"
models:
test:
cmd: "echo ${B}"
```
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
#### Test: LIFO substitution order
```yaml
macros:
"base": "/models"
"path": "${base}/llama"
"full": "${path}/model.gguf"
models:
test:
cmd: "load ${full}"
```
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
#### Test: Model macro overrides global
```yaml
macros:
"tag": "global"
"msg": "value-${tag}"
models:
test:
macros:
"tag": "model-level"
cmd: "echo ${msg}"
```
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
### 3. Reserved Macro Tests
#### Test: MODEL_ID substituted in macro
```yaml
macros:
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
models:
my-model:
cmd: "${podman-llama} -m model.gguf"
```
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
### 4. Error Detection Tests
#### Test: Self-reference detection
```yaml
macros:
"recursive": "value-${recursive}"
```
**Expected:** Error: `macro 'recursive' contains self-reference`
#### Test: Undefined macro reference
```yaml
macros:
"A": "value-${UNDEFINED}"
```
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
### 5. Regression Tests
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
- Ensure all pass without modification (except test fixtures if needed)
## Checklist
### Phase 1: Data Structure Changes
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
- [ ] Implement helper functions:
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
### Phase 2: Macro Validation Updates
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
- [ ] Test self-reference detection with new test case
### Phase 3: Macro Substitution Algorithm
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
- [ ] Substitute in metadata within same loop
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
- [ ] Ensure `PORT` is added after port assignment (if needed)
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
### Phase 4: Testing
- [ ] Run `make test-dev` - fix any static checking errors
- [ ] Add test: macro-in-macro basic substitution
- [ ] Add test: LIFO substitution order with 3+ macro levels
- [ ] Add test: MODEL_ID in global macro used by model
- [ ] Add test: PORT in global macro used by model
- [ ] Add test: model macro overrides global macro in substitution
- [ ] Add test: self-reference detection error
- [ ] Add test: undefined macro reference error
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
### Phase 5: Documentation
- [ ] Update plan status in this file (mark completed)
- [ ] Update CLAUDE.md if macro behavior needs documentation
- [ ] Verify no new error messages need user documentation
## Bug Example (Original Issue)
```yaml
macros:
"podman-llama": >
podman run --name ${MODEL_ID}
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
ghcr.io/ggml-org/llama.cpp:server-cuda
"standard-options": >
--no-mmap --jinja
"kv8": >
-fa on -ctk q8_0 -ctv q8_0
```
**Current Bug:**
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
**After Fix:**
- Macros substituted in LIFO order: `kv8``standard-options``podman-llama`
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
@@ -153,6 +153,19 @@ func main() {
})
// llama-server compatibility: /completion
r.POST("/completion", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
"usage": gin.H{
"completion_tokens": 10,
"prompt_tokens": 25,
"total_tokens": 35,
},
})
})
// issue #41
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
// Parse the multipart form
+27
View File
@@ -0,0 +1,27 @@
# wol-proxy
wol-proxy automatically wakes up a suspended llama-swap server using Wake-on-LAN when requests are received.
When a request arrives and llama-swap is unavailable, wol-proxy sends a WOL packet and holds the request until the server becomes available. If the server doesn't respond within the timeout period (default: 60 seconds), the request is dropped.
This utility helps conserve energy by allowing GPU-heavy servers to remain suspended when idle, as they can consume hundreds of watts even when not actively processing requests.
## Usage
```shell
# minimal
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080
# everything
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080 \
# use debug log level
-log debug \
# altenerative listening port
-listen localhost:9999 \
# seconds to hold requests waiting for upstream to be ready
-timeout 30
```
## API
`GET /status` - that's it. Everything else is proxied to the upstream server.
+64
View File
@@ -0,0 +1,64 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Loading...</title>
<style>
body {
font-family: sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background: #f5f5f5;
}
.loader {
text-align: center;
}
.stats {
font-size: 18px;
color: #333;
margin: 20px 0;
}
.stats-label {
color: #666;
font-size: 14px;
}
</style>
</head>
<body>
<div class="loader">
<p>Waking up upstream server...</p>
<div class="stats">
<div><span class="stats-label">Time elapsed:</span> <span id="elapsed">0s</span></div>
<div><span id="attempts">&nbsp;</span></div>
</div>
</div>
<script>
var startTime = Date.now();
var attempts = 0;
setInterval(function() {
var elapsed = (Date.now() - startTime) / 1000;
document.getElementById('elapsed').textContent = elapsed.toFixed(1) + 's';
}, 100);
// Check status every second
setInterval(function() {
attempts++;
var dots = '.'.repeat((attempts % 10) || 10);
document.getElementById('attempts').textContent = dots;
fetch('/status')
.then(function(r) { return r.text(); })
.then(function(t) {
if (t.indexOf('status: ready') !== -1) {
location.reload();
}
})
.catch(function() {});
}, 1000);
</script>
</body>
</html>
+333
View File
@@ -0,0 +1,333 @@
package main
import (
"bufio"
"context"
_ "embed"
"errors"
"flag"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"time"
)
//go:embed index.html
var loadingPageHTML string
var (
flagMac = flag.String("mac", "", "mac address to send WoL packet to")
flagUpstream = flag.String("upstream", "", "upstream proxy address to send requests to")
flagListen = flag.String("listen", ":8080", "listen address to listen on")
flagLog = flag.String("log", "info", "log level (debug, info, warn, error)")
flagTimeout = flag.Int("timeout", 60, "seconds requests wait for upstream response before failing")
)
func main() {
flag.Parse()
switch *flagLog {
case "debug":
slog.SetLogLoggerLevel(slog.LevelDebug)
case "info":
slog.SetLogLoggerLevel(slog.LevelInfo)
case "warn":
slog.SetLogLoggerLevel(slog.LevelWarn)
case "error":
slog.SetLogLoggerLevel(slog.LevelError)
default:
slog.Error("invalid log level", "logLevel", *flagLog)
return
}
// Validate flags
if *flagListen == "" {
slog.Error("listen address is required")
return
}
if *flagMac == "" {
slog.Error("mac address is required")
return
}
if *flagTimeout < 1 {
slog.Error("timeout must be greater than 0")
return
}
var upstreamURL *url.URL
var err error
// validate mac address
if _, err = net.ParseMAC(*flagMac); err != nil {
slog.Error("invalid mac address", "error", err)
return
}
if *flagUpstream == "" {
slog.Error("upstream proxy address is required")
return
} else {
upstreamURL, err = url.ParseRequestURI(*flagUpstream)
if err != nil {
slog.Error("error parsing upstream url", "error", err)
return
}
}
proxy := newProxy(upstreamURL)
server := &http.Server{
Addr: *flagListen,
Handler: proxy,
}
// start the server
go func() {
slog.Info("server starting on", "address", *flagListen)
if err := server.ListenAndServe(); err != nil {
slog.Error("error starting server", "error", err)
}
}()
// graceful shutdown
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
<-ctx.Done()
server.Close()
}
type upstreamStatus string
const (
notready upstreamStatus = "not ready"
ready upstreamStatus = "ready"
)
type proxyServer struct {
upstreamProxy *httputil.ReverseProxy
failCount int
statusMutex sync.RWMutex
status upstreamStatus
}
func newProxy(url *url.URL) *proxyServer {
p := httputil.NewSingleHostReverseProxy(url)
proxy := &proxyServer{
upstreamProxy: p,
status: notready,
failCount: 0,
}
// start a goroutine to monitor upstream status via SSE
go func() {
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
client := &http.Client{
Timeout: 0, // No timeout for SSE connection
}
waitDuration := 10 * time.Second
for {
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
req, err := http.NewRequest("GET", eventsUrl, nil)
if err != nil {
slog.Warn("failed to create SSE request", "error", err)
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(waitDuration)
continue
}
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
resp, err := client.Do(req)
if err != nil {
slog.Error("failed to connect to SSE endpoint", "error", err)
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(10 * time.Second)
continue
}
if resp.StatusCode != http.StatusOK {
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(10 * time.Second)
continue
}
// Successfully connected to SSE endpoint
slog.Info("connected to SSE endpoint, upstream ready")
proxy.setStatus(ready)
proxy.resetFailures()
// Read from the SSE stream to detect disconnection
scanner := bufio.NewScanner(resp.Body)
// use a fairly large buffer to avoid scanner errors when reading large SSE events
buf := make([]byte, 0, 1024*1024*2)
scanner.Buffer(buf, 1024*1024*2)
events := 0
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
fmt.Print("Events: ")
}
for scanner.Scan() {
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
// Just read the events to keep connection alive
// We don't need to process the event data
events++
fmt.Printf("%d, ", events)
}
}
fmt.Println()
if err := scanner.Err(); err != nil {
slog.Error("error reading from SSE stream", "error", err)
}
// Connection closed or error occurred
_ = resp.Body.Close()
slog.Info("SSE connection closed, upstream not ready")
proxy.setStatus(notready)
proxy.incFail(1)
// Wait before reconnecting
time.Sleep(waitDuration)
}
}()
return proxy
}
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && r.URL.Path == "/status" {
status := string(p.getStatus())
failCount := p.getFailures()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(200)
fmt.Fprintf(w, "status: %s\n", status)
fmt.Fprintf(w, "failures: %d\n", failCount)
return
}
if p.getStatus() == notready {
path := r.URL.Path
if strings.HasPrefix(path, "/api/events") {
slog.Debug("Skipping wake up", "req", path)
w.WriteHeader(http.StatusNoContent)
return
}
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
if err := sendMagicPacket(*flagMac); err != nil {
slog.Warn("failed to send magic WoL packet", "error", err)
}
// For root or UI path requests, return loading page with status polling
// the web page will do the polling and redirect when ready
if path == "/" || strings.HasPrefix(path, "/ui/") {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, loadingPageHTML)
return
}
ticker := time.NewTicker(250 * time.Millisecond)
timeout, cancel := context.WithTimeout(context.Background(), time.Duration(*flagTimeout)*time.Second)
defer cancel()
loop:
for {
select {
case <-timeout.Done():
slog.Info("timeout waiting for upstream to be ready")
http.Error(w, "timeout", http.StatusRequestTimeout)
return
case <-ticker.C:
if p.getStatus() == ready {
ticker.Stop()
break loop
}
}
}
}
p.upstreamProxy.ServeHTTP(w, r)
}
func (p *proxyServer) getStatus() upstreamStatus {
p.statusMutex.RLock()
defer p.statusMutex.RUnlock()
return p.status
}
func (p *proxyServer) setStatus(status upstreamStatus) {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.status = status
}
func (p *proxyServer) incFail(num int) {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.failCount += num
}
func (p *proxyServer) getFailures() int {
p.statusMutex.RLock()
defer p.statusMutex.RUnlock()
return p.failCount
}
func (p *proxyServer) resetFailures() {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.failCount = 0
}
func sendMagicPacket(macAddr string) error {
hwAddr, err := net.ParseMAC(macAddr)
if err != nil {
return err
}
if len(hwAddr) != 6 {
return errors.New("invalid MAC address")
}
// Create the magic packet.
packet := make([]byte, 102)
// Add 6 bytes of 0xFF.
for i := 0; i < 6; i++ {
packet[i] = 0xFF
}
// Repeat the MAC address 16 times.
for i := 1; i <= 16; i++ {
copy(packet[i*6:], hwAddr)
}
// Send the packet using UDP.
addr := net.UDPAddr{
IP: net.IPv4bcast,
Port: 9,
}
conn, err := net.DialUDP("udp", nil, &addr)
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Write(packet)
return err
}
+330
View File
@@ -0,0 +1,330 @@
{
"$schema": "https://json-schema.org/draft-07/schema#",
"$id": "llama-swap-config-schema.json",
"title": "llama-swap configuration",
"description": "Configuration file for llama-swap",
"type": "object",
"required": [
"models"
],
"definitions": {
"macros": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string",
"minLength": 0,
"maxLength": 1024
},
{
"type": "number"
},
{
"type": "boolean"
}
]
},
"propertyNames": {
"type": "string",
"minLength": 1,
"maxLength": 64,
"pattern": "^[a-zA-Z0-9_-]+$",
"not": {
"enum": [
"PORT",
"MODEL_ID"
]
}
},
"default": {},
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
}
},
"properties": {
"healthCheckTimeout": {
"type": "integer",
"minimum": 15,
"default": 120,
"description": "Number of seconds to wait for a model to be ready to serve requests."
},
"logLevel": {
"type": "string",
"enum": [
"debug",
"info",
"warn",
"error"
],
"default": "info",
"description": "Sets the logging value. Valid values: debug, info, warn, error."
},
"logTimeFormat": {
"type": "string",
"enum": [
"",
"ansic",
"unixdate",
"rubydate",
"rfc822",
"rfc822z",
"rfc850",
"rfc1123",
"rfc1123z",
"rfc3339",
"rfc3339nano",
"kitchen",
"stamp",
"stampmilli",
"stampmicro",
"stampnano"
],
"default": "",
"description": "Enables and sets the logging timestamp format. Valid values: \"\", \"ansic\", \"unixdate\", \"rubydate\", \"rfc822\", \"rfc822z\", \"rfc850\", \"rfc1123\", \"rfc1123z\", \"rfc3339\", \"rfc3339nano\", \"kitchen\", \"stamp\", \"stampmilli\", \"stampmicro\", and \"stampnano\". For more info, read: https://pkg.go.dev/time#pkg-constants"
},
"metricsMaxInMemory": {
"type": "integer",
"default": 1000,
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
},
"startPort": {
"type": "integer",
"default": 5800,
"description": "Starting port number for the automatic ${PORT} macro. The ${PORT} macro is incremented for every model that uses it."
},
"sendLoadingState": {
"type": "boolean",
"default": false,
"description": "Inject loading status updates into the reasoning field. When true, a stream of loading messages will be sent to the client."
},
"includeAliasesInList": {
"type": "boolean",
"default": false,
"description": "Present aliases within the /v1/models OpenAI API listing. when true, model aliases will be output to the API model listing duplicating all fields except for Id so chat UIs can use the alias equivalent to the original."
},
"macros": {
"$ref": "#/definitions/macros"
},
"models": {
"type": "object",
"description": "A dictionary of model configurations. Each key is a model's ID. Model settings have defaults if not defined. The model's ID is available as ${MODEL_ID}.",
"additionalProperties": {
"type": "object",
"required": [
"cmd"
],
"properties": {
"macros": {
"$ref": "#/definitions/macros"
},
"cmd": {
"type": "string",
"minLength": 1,
"description": "Command to run to start the inference server. Macros can be used. Comments allowed with |."
},
"cmdStop": {
"type": "string",
"default": "",
"description": "Command to run to stop the model gracefully. Uses ${PID} macro for upstream process id. If empty, default shutdown behavior is used."
},
"name": {
"type": "string",
"default": "",
"maxLength": 128,
"description": "Display name for the model. Used in v1/models API response."
},
"description": {
"type": "string",
"default": "",
"maxLength": 1024,
"description": "Description for the model. Used in v1/models API response."
},
"env": {
"type": "array",
"items": {
"type": "string",
"pattern": "^[A-Z_][A-Z0-9_]*=.*$"
},
"default": [],
"description": "Array of environment variables to inject into cmd's environment. Each value is a string in ENV_NAME=value format."
},
"proxy": {
"type": "string",
"default": "http://localhost:${PORT}",
"format": "uri",
"description": "URL where llama-swap routes API requests. If custom port is used in cmd, this must be set."
},
"aliases": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"default": [],
"description": "Alternative model names for this configuration. Must be unique globally."
},
"checkEndpoint": {
"type": "string",
"default": "/health",
"pattern": "^/.*$|^none$",
"description": "URL path to check if the server is ready. Use 'none' to skip health checking."
},
"ttl": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
},
"useModelName": {
"type": "string",
"default": "",
"description": "Override the model name sent to upstream server. Useful if upstream expects a different name."
},
"filters": {
"type": "object",
"properties": {
"stripParams": {
"type": "string",
"default": "",
"pattern": "^[a-zA-Z0-9_, ]*$",
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
}
},
"additionalProperties": false,
"default": {},
"description": "Dictionary of filter settings. Only stripParams is supported."
},
"metadata": {
"type": "object",
"additionalProperties": true,
"default": {},
"description": "Dictionary of arbitrary values included in /v1/models. Can contain complex types. Only passed through in /v1/models responses."
},
"concurrencyLimit": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Overrides allowed number of active parallel requests to a model. 0 uses internal default of 10. >0 overrides default. Requests exceeding limit get HTTP 429."
},
"sendLoadingState": {
"type": "boolean",
"description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting."
},
"unlisted": {
"type": "boolean",
"default": false,
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
}
}
}
},
"groups": {
"type": "object",
"additionalProperties": {
"type": "object",
"required": [
"members"
],
"properties": {
"swap": {
"type": "boolean",
"default": true,
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
},
"exclusive": {
"type": "boolean",
"default": true,
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
},
"persistent": {
"type": "boolean",
"default": false,
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
},
"members": {
"type": "array",
"items": {
"type": "string"
},
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
}
}
},
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
},
"hooks": {
"type": "object",
"properties": {
"on_startup": {
"type": "object",
"properties": {
"preload": {
"type": "array",
"items": {
"type": "string"
},
"default": [],
"description": "List of model IDs to load on startup. Model names must match keys in models. When preloading multiple models, define a group to prevent swapping."
}
},
"additionalProperties": false,
"description": "Actions to perform on startup. Only supported action is preload."
}
},
"additionalProperties": false,
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
},
"logToStdout": {
"type": "string",
"enum": [
"proxy",
"upstream",
"both",
"none"
],
"default": "proxy",
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
},
"apiKeys": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"default": [],
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
},
"peers": {
"type": "object",
"additionalProperties": {
"type": "object",
"required": [
"proxy",
"models"
],
"properties": {
"proxy": {
"type": "string",
"format": "uri",
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
},
"apiKey": {
"type": "string",
"default": "",
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
},
"models": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"description": "A list of models served by the peer."
}
}
},
"default": {},
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
}
}
}
+143 -17
View File
@@ -1,3 +1,6 @@
# add this modeline for validation in vscode
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
#
# llama-swap YAML configuration example
# -------------------------------------
#
@@ -23,6 +26,24 @@ healthCheckTimeout: 500
# - Valid log levels: debug, info, warn, error
logLevel: info
# logTimeFormat: enables and sets the logging timestamp format
# - optional, default (disabled): ""
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
# "stamp", "stampmilli", "stampmicro", and "stampnano".
# - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: ""
# logToStdout: controls what is logged to stdout
# - optional, default: "proxy"
# - valid values:
# - "proxy": logs generated by llama-swap when swapping models,
# handling requests, etc.
# - "upstream": a copy of an upstream processes stdout logs
# - "both": both the proxy and upstream logs interleaved together
# - "none": no logs are ever written to stdout
logToStdout: "proxy"
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
@@ -35,26 +56,69 @@ metricsMaxInMemory: 1000
# - it is automatically incremented for every model that uses it
startPort: 10001
# sendLoadingState: inject loading status updates into the reasoning (thinking)
# field
# - optional, default: false
# - when true, a stream of loading messages will be sent to the client in the
# reasoning field so chat UIs can show that loading is in progress.
# - see #366 for more details
sendLoadingState: true
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
# - optional, default: false
# - when true, model aliases will be output to the API model listing duplicating
# all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
# macros: a dictionary of string substitutions
# - optional, default: empty dictionary
# - macros are reusable snippets
# - used in a model's cmd, cmdStop, proxy and checkEndpoint
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
# - useful for reducing common configuration settings
# - macro names are strings and must be less than 64 characters
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
# - macro names must not be a reserved name: PORT or MODEL_ID
# - macro values can be numbers, bools, or strings
# - macros can contain other macros, but they must be defined before they are used
macros:
# Example of a multi-line macro
"latest-llama": >
/path/to/llama-server/llama-server-ec9e0301
--port ${PORT}
"default_ctx": 4096
# Example of macro-in-macro usage. macros can contain other macros
# but they must be previously declared.
"default_args": "--ctx-size ${default_ctx}"
# models: a dictionary of model configurations
# - required
# - each key is the model's ID, used in API requests
# - model settings have default values that are used if they are not defined here
# - below are examples of the various settings a model can have:
# - available model settings: env, cmd, cmdStop, proxy, aliases, checkEndpoint, ttl, unlisted
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
# - below are examples of the all the settings a model can have
models:
# keys are the model names used in API requests
"llama":
# macros: a dictionary of string substitutions specific to this model
# - optional, default: empty dictionary
# - macros defined here override macros defined in the global macros section
# - model level macros follow the same rules as global macros
macros:
"default_ctx": 16384
"temp": 0.7
# cmd: the command to run to start the inference server.
# - required
# - it is just a string, similar to what you would run on the CLI
@@ -64,6 +128,8 @@ models:
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model path/to/llama-8B-Q4_K_M.gguf
--ctx-size ${default_ctx}
--temperature ${temp}
# name: a display name for the model
# - optional, default: empty string
@@ -119,15 +185,39 @@ models:
# filters: a dictionary of filter settings
# - optional, default: empty dictionary
# - only strip_params is currently supported
# - only stripParams is currently supported
filters:
# strip_params: a comma separated list of parameters to remove from the request
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for server side enforcement of sampling parameters
# - the `model` parameter can never be removed
# - can be any JSON key in the request body
# - recommended to stick to sampling parameters
strip_params: "temperature, top_p, top_k"
stripParams: "temperature, top_p, top_k"
# metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple
# - metadata is only passed through in /v1/models responses
metadata:
# port will remain an integer
port: ${PORT}
# the ${temp} macro will remain a float
temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
a_list:
- 1
- 1.23
- "macros are OK in list and dictionary types: ${MODEL_ID}"
an_obj:
a: "1"
b: 2
# objects can contain complex types with macro substitution
# becomes: c: [0.7, false, "model: llama"]
c: ["${temp}", false, "model: ${MODEL_ID}"]
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
# - optional, default: 0
@@ -138,6 +228,10 @@ models:
# - recommended to be omitted and the default used
concurrencyLimit: 0
# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
@@ -148,12 +242,12 @@ models:
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker example:
# container run times like Docker and Podman can be used reliably with a
# a combination of cmd and cmdStop.
# container runtimes like Docker and Podman can be used reliably with
# a combination of cmd, cmdStop, and ${MODEL_ID}
"docker-llama":
proxy: "http://127.0.0.1:${PORT}"
cmd: |
docker run --name dockertest
docker run --name ${MODEL_ID}
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggml-org/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
@@ -167,7 +261,7 @@ models:
# - on POSIX systems: a SIGTERM signal is sent
# - on Windows, calls taskkill to stop the process
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop dockertest
cmdStop: docker stop ${MODEL_ID}
# groups: a dictionary of group settings
# - optional, default: empty dictionary
@@ -240,10 +334,42 @@ hooks:
# - optional, default: empty dictionary
# - the only supported action is preload
on_startup:
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
preload:
- "llama"
- "llama"
# peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary
# - peers can be another llama-swap
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
peers:
# keys is the peer'd ID
llama-swap-peer:
# proxy: a valid base URL to proxy requests to
# - required
# - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23
# models: a list of models served by the peer
# - required
models:
- model_a
- model_b
- embeddings/model_c
openrouter:
proxy: https://openrouter.ai/api
# apiKey: a string key to be injected into the request
# - optional, default: ""
# - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
apiKey: sk-your-openrouter-key
models:
- meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507
- deepseek/deepseek-v3.2
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
+46 -22
View File
@@ -20,36 +20,60 @@ if [[ -z "$GITHUB_TOKEN" ]]; then
exit 1
fi
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
# variable, this permits testing with forked llama.cpp repositories
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
# to enable easy container builds on forked repos
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
if [ "$ARCH" == "cpu" ]; then
# cpu only containers just use the latest available
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
echo "Building ${CONTAINER_LATEST} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_LATEST}
fi
# cpu only containers just use the server tag
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r '.[] | select(.metadata.container.tags[] | startswith("server")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
BASE_TAG=server-${LCPP_TAG}
else
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
BASE_TAG=server-${ARCH}-${LCPP_TAG}
fi
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
fi
for CONTAINER_TYPE in non-root root; do
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
USER_UID=0
USER_GID=0
USER_HOME=/root
if [ "$CONTAINER_TYPE" == "non-root" ]; then
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
USER_UID=10001
USER_GID=10001
USER_HOME=/app
fi
echo "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
--build-arg BASE_IMAGE=${BASE_IMAGE} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
done
+36 -8
View File
@@ -1,16 +1,44 @@
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
FROM ${BASE_IMAGE}:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=89
ARG LS_VER=170
ARG LS_REPO=mostlygeek/llama-swap
# Set default UID/GID arguments
ARG UID=10001
ARG GID=10001
ARG USER_HOME=/app
# Add user/group
ENV HOME=$USER_HOME
RUN if [ $UID -ne 0 ]; then \
if [ $GID -ne 0 ]; then \
groupadd --system --gid $GID app; \
fi; \
useradd --system --uid $UID --gid $GID \
--home $USER_HOME app; \
fi
# Handle paths
RUN mkdir --parents $HOME /app
RUN chown --recursive $UID:$GID $HOME /app
# Switch user
USER $UID:$GID
WORKDIR /app
RUN \
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
COPY config.example.yaml /app/config.yaml
# Add /app to PATH
ENV PATH="/app:${PATH}"
RUN \
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]

Before

Width:  |  Height:  |  Size: 261 KiB

After

Width:  |  Height:  |  Size: 261 KiB

Before

Width:  |  Height:  |  Size: 351 KiB

After

Width:  |  Height:  |  Size: 351 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

+467
View File
@@ -0,0 +1,467 @@
# config.yaml
llama-swap is designed to be very simple: one binary, one configuration file.
## minimal viable config
```yaml
models:
model1:
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
```
This is enough to launch `llama-server` to serve `model1`. Of course, llama-swap is about making it possible to serve many models:
```yaml
models:
model1:
cmd: llama-server --port ${PORT} -m /path/to/model.gguf
model2:
cmd: llama-server --port ${PORT} -m /path/to/another_model.gguf
model3:
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
```
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model. Useful if you want to run multiple models at the same time with the `groups` feature.
## Advanced control with `cmd`
llama-swap is also about customizability. You can use any CLI flag available:
```yaml
models:
model1:
cmd: | # support for multi-line
llama-server --PORT ${PORT} -m /path/to/model.gguf
--ctx-size 8192
--jinja
--cache-type-k q8_0
--cache-type-v q8_0
```
## Support for any OpenAI API compatible server
llama-swap supports any OpenAI API compatible server. If you can run it on the CLI llama-swap will be able to manage it. Even if it's run in Docker or Podman containers.
```yaml
models:
"Q3-30B-CODER-VLLM":
name: "Qwen3 30B Coder vllm AWQ (Q3-30B-CODER-VLLM)"
# cmdStop provides a reliable way to stop containers
cmdStop: docker stop vllm-coder
cmd: |
docker run --init --rm --name vllm-coder
--runtime=nvidia --gpus '"device=2,3"'
--shm-size=16g
-v /mnt/nvme/vllm-cache:/root/.cache
-v /mnt/ssd-extra/models:/models -p ${PORT}:8000
vllm/vllm-openai:v0.10.0
--model "/models/cpatonn/Qwen3-Coder-30B-A3B-Instruct-AWQ"
--served-model-name "Q3-30B-CODER-VLLM"
--enable-expert-parallel
--swap-space 16
--max-num-seqs 512
--max-model-len 65536
--max-seq-len-to-capture 65536
--gpu-memory-utilization 0.9
--tensor-parallel-size 2
--trust-remote-code
```
## Many more features..
llama-swap supports many more features to customize how you want to manage your environment.
| Feature | Description |
| --------- | ---------------------------------------------- |
| `ttl` | automatic unloading of models after a timeout |
| `macros` | reusable snippets to use in configurations |
| `groups` | run multiple models at a time |
| `hooks` | event driven functionality |
| `env` | define environment variables per model |
| `aliases` | serve a model with different names |
| `filters` | modify requests before sending to the upstream |
| `...` | And many more tweaks |
## Full Configuration Example
> [!NOTE]
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
```yaml
# add this modeline for validation in vscode
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
#
# llama-swap YAML configuration example
# -------------------------------------
#
# 💡 Tip - Use an LLM with this file!
# ====================================
# This example configuration is written to be LLM friendly. Try
# copying this file into an LLM and asking it to explain or generate
# sections for you.
# ====================================
# Usage notes:
# - Below are all the available configuration options for llama-swap.
# - Settings noted as "required" must be in your configuration file
# - Settings noted as "optional" can be omitted
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
# - optional, default: 120
# - minimum value is 15 seconds, anything less will be set to this value
healthCheckTimeout: 500
# logLevel: sets the logging value
# - optional, default: info
# - Valid log levels: debug, info, warn, error
logLevel: info
# logTimeFormat: enables and sets the logging timestamp format
# - optional, default (disabled): ""
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
# "stamp", "stampmilli", "stampmicro", and "stampnano".
# - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: ""
# logToStdout: controls what is logged to stdout
# - optional, default: "proxy"
# - valid values:
# - "proxy": logs generated by llama-swap when swapping models,
# handling requests, etc.
# - "upstream": a copy of an upstream processes stdout logs
# - "both": both the proxy and upstream logs interleaved together
# - "none": no logs are ever written to stdout
logToStdout: "proxy"
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
# - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
# - it is automatically incremented for every model that uses it
startPort: 10001
# sendLoadingState: inject loading status updates into the reasoning (thinking)
# field
# - optional, default: false
# - when true, a stream of loading messages will be sent to the client in the
# reasoning field so chat UIs can show that loading is in progress.
# - see #366 for more details
sendLoadingState: true
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
# - optional, default: false
# - when true, model aliases will be output to the API model listing duplicating
# all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
# macros: a dictionary of string substitutions
# - optional, default: empty dictionary
# - macros are reusable snippets
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
# - useful for reducing common configuration settings
# - macro names are strings and must be less than 64 characters
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
# - macro names must not be a reserved name: PORT or MODEL_ID
# - macro values can be numbers, bools, or strings
# - macros can contain other macros, but they must be defined before they are used
macros:
# Example of a multi-line macro
"latest-llama": >
/path/to/llama-server/llama-server-ec9e0301
--port ${PORT}
"default_ctx": 4096
# Example of macro-in-macro usage. macros can contain other macros
# but they must be previously declared.
"default_args": "--ctx-size ${default_ctx}"
# models: a dictionary of model configurations
# - required
# - each key is the model's ID, used in API requests
# - model settings have default values that are used if they are not defined here
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
# - below are examples of the all the settings a model can have
models:
# keys are the model names used in API requests
"llama":
# macros: a dictionary of string substitutions specific to this model
# - optional, default: empty dictionary
# - macros defined here override macros defined in the global macros section
# - model level macros follow the same rules as global macros
macros:
"default_ctx": 16384
"temp": 0.7
# cmd: the command to run to start the inference server.
# - required
# - it is just a string, similar to what you would run on the CLI
# - using `|` allows for comments in the command, these will be parsed out
# - macros can be used within cmd
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model path/to/llama-8B-Q4_K_M.gguf
--ctx-size ${default_ctx}
--temperature ${temp}
# name: a display name for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
name: "llama 3.1 8B"
# description: a description for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
description: "A small but capable model used for quick testing"
# env: define an array of environment variables to inject into cmd's environment
# - optional, default: empty array
# - each value is a single string
# - in the format: ENV_NAME=value
env:
- "CUDA_VISIBLE_DEVICES=0,1,2"
# proxy: the URL where llama-swap routes API requests
# - optional, default: http://localhost:${PORT}
# - if you used ${PORT} in cmd this can be omitted
# - if you use a custom port in cmd this *must* be set
proxy: http://127.0.0.1:8999
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# checkEndpoint: URL path to check if the server is ready
# - optional, default: /health
# - endpoint is expected to return an HTTP 200 response
# - all requests wait until the endpoint is ready or fails
# - use "none" to skip endpoint health checking
checkEndpoint: /custom-endpoint
# ttl: automatically unload the model after ttl seconds
# - optional, default: 0
# - ttl values must be a value greater than 0
# - a value of 0 disables automatic unloading of the model
ttl: 60
# useModelName: override the model name that is sent to upstream server
# - optional, default: ""
# - useful for when the upstream server expects a specific model name that
# is different from the model's ID
useModelName: "qwen:qwq"
# filters: a dictionary of filter settings
# - optional, default: empty dictionary
# - only stripParams is currently supported
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for server side enforcement of sampling parameters
# - the `model` parameter can never be removed
# - can be any JSON key in the request body
# - recommended to stick to sampling parameters
stripParams: "temperature, top_p, top_k"
# metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple
# - metadata is only passed through in /v1/models responses
metadata:
# port will remain an integer
port: ${PORT}
# the ${temp} macro will remain a float
temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
a_list:
- 1
- 1.23
- "macros are OK in list and dictionary types: ${MODEL_ID}"
an_obj:
a: "1"
b: 2
# objects can contain complex types with macro substitution
# becomes: c: [0.7, false, "model: llama"]
c: ["${temp}", false, "model: ${MODEL_ID}"]
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
# - optional, default: 0
# - useful for limiting the number of active parallel requests a model can process
# - must be set per model
# - any number greater than 0 will override the internal default value of 10
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
# - recommended to be omitted and the default used
concurrencyLimit: 0
# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
# - optional, default: false
# - unlisted models do not show up in /v1/models api requests
# - can be requested as normal through all apis
unlisted: true
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker example:
# container runtimes like Docker and Podman can be used reliably with
# a combination of cmd, cmdStop, and ${MODEL_ID}
"docker-llama":
proxy: "http://127.0.0.1:${PORT}"
cmd: |
docker run --name ${MODEL_ID}
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggml-org/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# cmdStop: command to run to stop the model gracefully
# - optional, default: ""
# - useful for stopping commands managed by another system
# - the upstream's process id is available in the ${PID} macro
#
# When empty, llama-swap has this default behaviour:
# - on POSIX systems: a SIGTERM signal is sent
# - on Windows, calls taskkill to stop the process
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID}
# groups: a dictionary of group settings
# - optional, default: empty dictionary
# - provides advanced controls over model swapping behaviour
# - using groups some models can be kept loaded indefinitely, while others are swapped out
# - model IDs must be defined in the Models section
# - a model can only be a member of one group
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
# - see issue #109 for details
#
# NOTE: the example below uses model names that are not defined above for demonstration purposes
groups:
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap: controls the model swapping behaviour in within the group
# - optional, default: true
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# exclusive: controls how the group affects other groups
# - optional, default: true
# - true: causes all other groups to unload when this group runs a model
# - false: does not affect other groups
exclusive: true
# members references the models defined above
# required
members:
- "llama"
- "qwen-unlisted"
# Example:
# - in group2 all models can run at the same time
# - when a different group is loaded it causes all running models in this group to unload
"group2":
swap: false
# exclusive: false does not unload other groups when a model in group2 is requested
# - the models in group2 will be loaded but will not unload any other groups
exclusive: false
members:
- "docker-llama"
- "modelA"
- "modelB"
# Example:
# - a persistent group, prevents other groups from unloading it
"forever":
# persistent: prevents over groups from unloading the models in this group
# - optional, default: false
# - does not affect individual model behaviour
persistent: true
# set swap/exclusive to false to prevent swapping inside the group
# and the unloading of other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
# hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary
# - the only supported hook is on_startup
hooks:
# on_startup: a dictionary of actions to perform on startup
# - optional, default: empty dictionary
# - the only supported action is preload
on_startup:
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
preload:
- "llama"
# peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary
# - peers can be another llama-swap
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
peers:
# keys is the peer'd ID
llama-swap-peer:
# proxy: a valid base URL to proxy requests to
# - required
# - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23
# models: a list of models served by the peer
# - required
models:
- model_a
- model_b
- embeddings/model_c
openrouter:
proxy: https://openrouter.ai/api
# apiKey: a string key to be injected into the request
# - optional, default: ""
# - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
apiKey: sk-your-openrouter-key
models:
- meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507
- deepseek/deepseek-v3.2
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
```
+9
View File
@@ -0,0 +1,9 @@
## Container Security
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
+5 -5
View File
@@ -1,6 +1,6 @@
module github.com/mostlygeek/llama-swap
go 1.23.0
go 1.25.4
require (
github.com/billziss-gh/golib v0.2.0
@@ -37,9 +37,9 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+8 -8
View File
@@ -80,16 +80,16 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
+40 -9
View File
@@ -16,6 +16,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy"
"github.com/mostlygeek/llama-swap/proxy/config"
)
var (
@@ -27,7 +28,9 @@ var (
func main() {
// Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port")
listenStr := flag.String("listen", "", "listen ip/port")
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
keyFile := flag.String("tls-key-file", "", "TLS key file")
showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
@@ -38,13 +41,13 @@ func main() {
os.Exit(0)
}
config, err := proxy.LoadConfig(*configPath)
conf, err := config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
if len(config.Profiles) > 0 {
if len(conf.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}
@@ -54,6 +57,23 @@ func main() {
gin.SetMode(gin.ReleaseMode)
}
// Validate TLS flags.
var useTLS = (*certFile != "" && *keyFile != "")
if (*certFile != "" && *keyFile == "") ||
(*certFile == "" && *keyFile != "") {
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
os.Exit(1)
}
// Set default ports.
if *listenStr == "" {
defaultPort := ":8080"
if useTLS {
defaultPort = ":8443"
}
listenStr = &defaultPort
}
// Setup channels for server management
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1)
@@ -67,7 +87,7 @@ func main() {
// Support for watching config and reloading when it changes
reloadProxyManager := func() {
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
config, err = proxy.LoadConfig(*configPath)
conf, err = config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
return
@@ -75,7 +95,9 @@ func main() {
fmt.Println("Configuration Changed")
currentPM.Shutdown()
srv.Handler = proxy.New(config)
newPM := proxy.New(conf)
newPM.SetVersion(date, commit, version)
srv.Handler = newPM
fmt.Println("Configuration Reloaded")
// wait a few seconds and tell any UI to reload
@@ -85,12 +107,14 @@ func main() {
})
})
} else {
config, err = proxy.LoadConfig(*configPath)
conf, err = config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error, unable to load configuration: %v\n", err)
os.Exit(1)
}
srv.Handler = proxy.New(config)
newPM := proxy.New(conf)
newPM.SetVersion(date, commit, version)
srv.Handler = newPM
}
}
@@ -166,9 +190,16 @@ func main() {
}()
// Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
var err error
if useTLS {
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
err = srv.ListenAndServeTLS(*certFile, *keyFile)
} else {
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
err = srv.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
log.Fatalf("Fatal server error: %v\n", err)
}
}()
Binary file not shown.

Before

Width:  |  Height:  |  Size: 51 KiB

-454
View File
@@ -1,454 +0,0 @@
package proxy
import (
"fmt"
"io"
"os"
"regexp"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3"
)
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
}
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelConfig(defaults)
return nil
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
type ModelFilters struct {
StripParams string `yaml:"strip_params"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelFilters(defaults)
return nil
}
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" {
continue
}
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
}
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type HooksConfig struct {
OnStartup HookOnStartup `yaml:"on_startup"`
}
type HookOnStartup struct {
Preload []string `yaml:"preload"`
}
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros map[string]string `yaml:"macros"`
// map aliases to actual model IDs
aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
// hooks, see: #209
Hooks HooksConfig `yaml:"hooks"`
}
func (c *Config) RealModelName(search string) (string, bool) {
if _, found := c.Models[search]; found {
return search, true
} else if name, found := c.aliases[search]; found {
return name, found
} else {
return "", false
}
}
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
if realName, found := c.RealModelName(modelName); !found {
return ModelConfig{}, "", false
} else {
return c.Models[realName], realName, true
}
}
func LoadConfig(path string) (Config, error) {
file, err := os.Open(path)
if err != nil {
return Config{}, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
}
// default configuration values
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
MetricsMaxInMemory: 1000,
}
err = yaml.Unmarshal(data, &config)
if err != nil {
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
// set a minimum of 15 seconds
config.HealthCheckTimeout = 15
}
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName
}
}
/* check macro constraint rules:
- name must fit the regex ^[a-zA-Z0-9_-]+$
- names must be less than 64 characters (no reason, just cause)
- name can not be any reserved macros: PORT
- macro values must be less than 1024 characters
*/
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
for macroName, macroValue := range config.Macros {
if len(macroName) >= 64 {
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
}
if !macroNameRegex.MatchString(macroName) {
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
}
if len(macroValue) >= 1024 {
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
}
switch macroName {
case "PORT":
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
}
}
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
// Strip comments from command fields before macro expansion
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range config.Macros {
macroSlug := fmt.Sprintf("${%s}", macroName)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
}
// enforce ${PORT} used in both cmd and proxy
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
nextPortStr := strconv.Itoa(nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
nextPort++
}
// make sure there are no unknown macros that have not been replaced
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
}
for fieldName, fieldValue := range fieldMap {
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
continue // this is ok, has to be replaced by process later
}
if _, exists := config.Macros[macroName]; !exists {
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
}
config.Models[modelId] = modelConfig
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
// clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string
for _, modelID := range config.Hooks.OnStartup.Preload {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
if real, found := config.RealModelName(modelID); found {
toPreload = append(toPreload, real)
}
}
config.Hooks.OnStartup.Preload = toPreload
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
}
func SanitizeCommand(cmdStr string) ([]string, error) {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments
var args []string
if runtime.GOOS == "windows" {
args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
}
// Ensure the command is not empty
if len(args) == 0 {
return nil, fmt.Errorf("empty command")
}
return args, nil
}
func StripComments(cmdStr string) string {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
cleanedLines = append(cleanedLines, line)
}
return strings.Join(cleanedLines, "\n")
}
+647
View File
@@ -0,0 +1,647 @@
package config
import (
"fmt"
"io"
"net/url"
"os"
"regexp"
"runtime"
"sort"
"strings"
"github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3"
)
const DEFAULT_GROUP_ID = "(default)"
const (
LogToStdoutProxy = "proxy"
LogToStdoutUpstream = "upstream"
LogToStdoutBoth = "both"
LogToStdoutNone = "none"
)
type MacroEntry struct {
Name string
Value any
}
type MacroList []MacroEntry
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.MappingNode {
return fmt.Errorf("macros must be a mapping")
}
// yaml.Node.Content for a mapping contains alternating key/value nodes
entries := make([]MacroEntry, 0, len(value.Content)/2)
for i := 0; i < len(value.Content); i += 2 {
keyNode := value.Content[i]
valueNode := value.Content[i+1]
var name string
if err := keyNode.Decode(&name); err != nil {
return fmt.Errorf("failed to decode macro name: %w", err)
}
var val any
if err := valueNode.Decode(&val); err != nil {
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
}
entries = append(entries, MacroEntry{Name: name, Value: val})
}
*ml = entries
return nil
}
// Get retrieves a macro value by name
func (ml MacroList) Get(name string) (any, bool) {
for _, entry := range ml {
if entry.Name == name {
return entry.Value, true
}
}
return nil, false
}
// ToMap converts MacroList to a map (for backward compatibility if needed)
func (ml MacroList) ToMap() map[string]any {
result := make(map[string]any, len(ml))
for _, entry := range ml {
result[entry.Name] = entry.Value
}
return result
}
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
)
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type HooksConfig struct {
OnStartup HookOnStartup `yaml:"on_startup"`
}
type HookOnStartup struct {
Preload []string `yaml:"preload"`
}
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
LogTimeFormat string `yaml:"logTimeFormat"`
LogToStdout string `yaml:"logToStdout"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros MacroList `yaml:"macros"`
// map aliases to actual model IDs
aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
// hooks, see: #209
Hooks HooksConfig `yaml:"hooks"`
// send loading state in reasoning
SendLoadingState bool `yaml:"sendLoadingState"`
// present aliases to /v1/models OpenAI API listing
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
// support API keys, see issue #433, #50, #251
RequiredAPIKeys []string `yaml:"apiKeys"`
// support remote peers, see issue #433, #296
Peers PeerDictionaryConfig `yaml:"peers"`
}
func (c *Config) RealModelName(search string) (string, bool) {
if _, found := c.Models[search]; found {
return search, true
} else if name, found := c.aliases[search]; found {
return name, found
} else {
return "", false
}
}
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
if realName, found := c.RealModelName(modelName); !found {
return ModelConfig{}, "", false
} else {
return c.Models[realName], realName, true
}
}
func LoadConfig(path string) (Config, error) {
file, err := os.Open(path)
if err != nil {
return Config{}, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
}
// default configuration values
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
MetricsMaxInMemory: 1000,
}
err = yaml.Unmarshal(data, &config)
if err != nil {
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
// set a minimum of 15 seconds
config.HealthCheckTimeout = 15
}
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
switch config.LogToStdout {
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
default:
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName
}
}
/* check macro constraint rules:
- name must fit the regex ^[a-zA-Z0-9_-]+$
- names must be less than 64 characters (no reason, just cause)
- name can not be any reserved macros: PORT, MODEL_ID
- macro values must be less than 1024 characters
*/
for _, macro := range config.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, err
}
}
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
// Strip comments from command fields before macro expansion
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// validate model macros
for _, macro := range modelConfig.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
}
}
// Merge global config and model macros. Model macros take precedence
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
// Add global macros first
mergedMacros = append(mergedMacros, config.Macros...)
// Add model macros (can override global)
for _, entry := range modelConfig.Macros {
// Remove any existing global macro with same name
found := false
for i, existing := range mergedMacros {
if existing.Name == entry.Name {
mergedMacros[i] = entry // Override
found = true
break
}
}
if !found {
mergedMacros = append(mergedMacros, entry)
}
}
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
// This allows later macros to reference earlier ones
for i := len(mergedMacros) - 1; i >= 0; i-- {
entry := mergedMacros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
// Substitute in command fields
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
// Substitute in metadata (recursive)
if len(modelConfig.Metadata) > 0 {
var err error
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
}
// Final pass: check if PORT macro is needed after macro expansion
// ${PORT} is a resource on the local machine so a new port is only allocated
// if it is required in either cmd or proxy keys
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
if cmdHasPort || proxyHasPort { // either has it
if !cmdHasPort && proxyHasPort { // but both don't have it
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// Add PORT macro and substitute it
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
macroSlug := "${PORT}"
macroStr := fmt.Sprintf("%v", nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
// Substitute PORT in metadata
if len(modelConfig.Metadata) > 0 {
var err error
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
nextPort++
}
// make sure there are no unknown macros that have not been replaced
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
"filters.stripParams": modelConfig.Filters.StripParams,
}
for fieldName, fieldValue := range fieldMap {
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
continue // this is ok, has to be replaced by process later
}
// Reserved macros are always valid (they should have been substituted already)
if macroName == "PORT" || macroName == "MODEL_ID" {
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
}
// Any other macro is unknown
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
// Check for unknown macros in metadata
if len(modelConfig.Metadata) > 0 {
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
return Config{}, err
}
}
// Validate the proxy URL.
if _, err := url.Parse(modelConfig.Proxy); err != nil {
return Config{}, fmt.Errorf(
"model %s: invalid proxy URL: %w", modelId, err,
)
}
// if sendLoadingState is nil, set it to the global config value
// see #366
if modelConfig.SendLoadingState == nil {
v := config.SendLoadingState // copy it
modelConfig.SendLoadingState = &v
}
config.Models[modelId] = modelConfig
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
// clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string
for _, modelID := range config.Hooks.OnStartup.Preload {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
if real, found := config.RealModelName(modelID); found {
toPreload = append(toPreload, real)
}
}
config.Hooks.OnStartup.Preload = toPreload
}
// check api keys validatity
for _, apikey := range config.RequiredAPIKeys {
if apikey == "" {
return Config{}, fmt.Errorf("empty api key found in apiKeys")
}
if strings.Contains(apikey, " ") {
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
}
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
}
func SanitizeCommand(cmdStr string) ([]string, error) {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments
var args []string
if runtime.GOOS == "windows" {
args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
}
// Ensure the command is not empty
if len(args) == 0 {
return nil, fmt.Errorf("empty command")
}
return args, nil
}
func StripComments(cmdStr string) string {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
cleanedLines = append(cleanedLines, line)
}
return strings.Join(cleanedLines, "\n")
}
// validateMacro validates macro name and value constraints
func validateMacro(name string, value any) error {
if len(name) >= 64 {
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
}
if !macroNameRegex.MatchString(name) {
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
}
// Validate that value is a scalar type
switch v := value.(type) {
case string:
if len(v) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
// Check for self-reference
macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(v, macroSlug) {
return fmt.Errorf("macro '%s' contains self-reference", name)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
// These types are allowed
default:
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
}
switch name {
case "PORT", "MODEL_ID":
return fmt.Errorf("macro name '%s' is reserved", name)
}
return nil
}
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
func validateMetadataForUnknownMacros(value any, modelId string) error {
switch v := value.(type) {
case string:
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
for _, match := range matches {
macroName := match[1]
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
}
return nil
case map[string]any:
for _, val := range v {
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
return err
}
}
return nil
case []any:
for _, val := range v {
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
return err
}
}
return nil
default:
// Scalar types don't contain macros
return nil
}
}
// substituteMacroInValue recursively substitutes a single macro in a value structure
// This is called once per macro, allowing LIFO substitution order
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
macroSlug := fmt.Sprintf("${%s}", macroName)
macroStr := fmt.Sprintf("%v", macroValue)
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if v == macroSlug {
return macroValue, nil
}
// Handle string interpolation
if strings.Contains(v, macroSlug) {
return strings.ReplaceAll(v, macroSlug, macroStr), nil
}
return v, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newMap[key] = newVal
}
return newMap, nil
case []any:
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newSlice[i] = newVal
}
return newSlice, nil
default:
// Return scalar types as-is
return value, nil
}
}
@@ -1,6 +1,6 @@
//go:build !windows
package proxy
package config
import (
"os"
@@ -58,6 +58,7 @@ models:
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)
assert.Equal(t, "", config.LogTimeFormat)
// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
@@ -160,47 +161,56 @@ groups:
t.Fatalf("Failed to load config: %v", err)
}
modelLoadingState := false
expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
"svr-path": "path/to/server",
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
StartPort: 5800,
Macros: MacroList{
{"svr-path", "path/to/server"},
},
Hooks: HooksConfig{
OnStartup: HookOnStartup{
Preload: []string{"model1", "model2"},
},
},
SendLoadingState: false,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
SendLoadingState: &modelLoadingState,
},
"model2": {
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
},
},
HealthCheckTimeout: 15,
@@ -1,4 +1,4 @@
package proxy
package config
import (
"slices"
@@ -65,18 +65,6 @@ models:
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
--arg1 value1 \
--arg2 value2`,
}
args, err := config.SanitizedCommand()
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestConfig_FindConfig(t *testing.T) {
// TODO?
@@ -207,30 +195,91 @@ macros:
argOne: "--arg1"
argTwo: "--arg2"
autoPort: "--port ${PORT}"
overriddenByModelMacro: failed
models:
model1:
macros:
overriddenByModelMacro: success
cmd: |
${svr-path} ${argTwo}
# the automatic ${PORT} is replaced
${autoPort}
${argOne}
--arg3 three
--overridden ${overriddenByModelMacro}
cmdStop: |
/path/to/stop.sh --port ${PORT} ${argTwo}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
if !assert.NoError(t, err) {
t.FailNow()
}
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
assert.NoError(t, err)
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
}
func TestConfig_MacroReservedNames(t *testing.T) {
tests := []struct {
name string
config string
expectedError string
}{
{
name: "global macro named PORT",
config: `
macros:
PORT: "1111"
`,
expectedError: "macro name 'PORT' is reserved",
},
{
name: "global macro named MODEL_ID",
config: `
macros:
MODEL_ID: model1
`,
expectedError: "macro name 'MODEL_ID' is reserved",
},
{
name: "model macro named PORT",
config: `
models:
model1:
macros:
PORT: 1111
`,
expectedError: "model model1: macro name 'PORT' is reserved",
},
{
name: "model macro named MODEL_ID",
config: `
models:
model1:
macros:
MODEL_ID: model1
`,
expectedError: "model model1: macro name 'MODEL_ID' is reserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.config))
assert.NotNil(t, err)
assert.Equal(t, tt.expectedError, err.Error())
})
}
}
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
tests := []struct {
name string
@@ -274,7 +323,7 @@ macros:
models:
model1:
cmd: "${svr-path} --port ${PORT}"
proxy: "http://localhost:${unknownMacro}"
proxy: "http://${unknownMacro}:${PORT}"
`,
},
{
@@ -288,6 +337,20 @@ models:
model1:
cmd: "${svr-path} --port ${PORT}"
checkEndpoint: "http://localhost:${unknownMacro}/health"
`,
},
{
name: "unknown macro in filters.stripParams",
field: "filters.stripParams",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: "${svr-path} --port ${PORT}"
filters:
stripParams: "model,${unknownMacro}"
`,
},
}
@@ -295,38 +358,13 @@ models:
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
}
//t.Log(err)
})
}
}
func TestConfig_ModelFilters(t *testing.T) {
content := `
macros:
default_strip: "temperature, top_p"
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
strip_params: "model, top_k, ${default_strip}, , ,"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
modelConfig, ok := config.Models["model1"]
if !assert.True(t, ok) {
t.FailNow()
}
// make sure `model` and enmpty strings are not in the list
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
sanitized, err := modelConfig.Filters.SanitizedStripParams()
if assert.NoError(t, err) {
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
}
}
func TestStripComments(t *testing.T) {
tests := []struct {
name string
@@ -440,3 +478,334 @@ models:
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
}
func TestConfig_MacroModelId(t *testing.T) {
content := `
startPort: 9000
macros:
"docker-llama": docker run --name ${MODEL_ID} -p ${PORT}:8080 docker_img
"docker-stop": docker stop ${MODEL_ID}
models:
model1:
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
model2:
cmd: ${docker-llama}
cmdStop: ${docker-stop}
author/model:F16:
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
cmdStop: stop
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
dockerStopMacro, found := config.Macros.Get("docker-stop")
assert.True(t, found)
assert.Equal(t, "docker stop ${MODEL_ID}", dockerStopMacro)
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "docker run --name model2 -p 9002:8080 docker_img", strings.Join(sanitizedCmd2, " "))
sanitizedCmdStop, err := SanitizeCommand(config.Models["model2"].CmdStop)
assert.NoError(t, err)
assert.Equal(t, "docker stop model2", strings.Join(sanitizedCmdStop, " "))
sanitizedCmd3, err := SanitizeCommand(config.Models["author/model:F16"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
}
func TestConfig_TypedMacrosInMetadata(t *testing.T) {
content := `
startPort: 10000
macros:
PORT_NUM: 10001
TEMP: 0.7
ENABLED: true
NAME: "llama model"
CTX: 16384
models:
test-model:
cmd: /path/to/server -p ${PORT}
metadata:
port: ${PORT_NUM}
temperature: ${TEMP}
enabled: ${ENABLED}
model_name: ${NAME}
context: ${CTX}
note: "Running on port ${PORT_NUM} with temp ${TEMP} and context ${CTX}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
meta := config.Models["test-model"].Metadata
assert.NotNil(t, meta)
// Verify direct substitution preserves types
assert.Equal(t, 10001, meta["port"])
assert.Equal(t, 0.7, meta["temperature"])
assert.Equal(t, true, meta["enabled"])
assert.Equal(t, "llama model", meta["model_name"])
assert.Equal(t, 16384, meta["context"])
// Verify string interpolation converts to string
assert.Equal(t, "Running on port 10001 with temp 0.7 and context 16384", meta["note"])
}
func TestConfig_NestedStructuresInMetadata(t *testing.T) {
content := `
startPort: 10000
macros:
TEMP: 0.7
models:
test-model:
cmd: /path/to/server -p ${PORT}
metadata:
config:
port: ${PORT}
temperature: ${TEMP}
tags: ["model:${MODEL_ID}", "port:${PORT}"]
nested:
deep:
value: ${TEMP}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
meta := config.Models["test-model"].Metadata
assert.NotNil(t, meta)
// Verify nested objects
configMap := meta["config"].(map[string]any)
assert.Equal(t, 10000, configMap["port"])
assert.Equal(t, 0.7, configMap["temperature"])
// Verify arrays
tags := meta["tags"].([]any)
assert.Equal(t, "model:test-model", tags[0])
assert.Equal(t, "port:10000", tags[1])
// Verify deeply nested structures
nested := meta["nested"].(map[string]any)
deep := nested["deep"].(map[string]any)
assert.Equal(t, 0.7, deep["value"])
}
func TestConfig_ModelLevelMacroPrecedenceInMetadata(t *testing.T) {
content := `
startPort: 10000
macros:
TEMP: 0.5
GLOBAL_VAL: "global"
models:
test-model:
cmd: /path/to/server -p ${PORT}
macros:
TEMP: 0.9
LOCAL_VAL: "local"
metadata:
temperature: ${TEMP}
global: ${GLOBAL_VAL}
local: ${LOCAL_VAL}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
meta := config.Models["test-model"].Metadata
assert.NotNil(t, meta)
// Model-level macro should override global
assert.Equal(t, 0.9, meta["temperature"])
// Global macro should be accessible
assert.Equal(t, "global", meta["global"])
// Model-level macro should be accessible
assert.Equal(t, "local", meta["local"])
}
func TestConfig_UnknownMacroInMetadata(t *testing.T) {
content := `
startPort: 10000
models:
test-model:
cmd: /path/to/server -p ${PORT}
metadata:
value: ${UNKNOWN_MACRO}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "test-model")
assert.Contains(t, err.Error(), "UNKNOWN_MACRO")
}
func TestConfig_InvalidMacroType(t *testing.T) {
content := `
startPort: 10000
macros:
INVALID:
nested: value
models:
test-model:
cmd: /path/to/server -p ${PORT}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "INVALID")
assert.Contains(t, err.Error(), "must be a scalar type")
}
func TestConfig_MacroTypeValidation(t *testing.T) {
tests := []struct {
name string
yaml string
shouldErr bool
}{
{
name: "string macro",
yaml: `
startPort: 10000
macros:
STR: "test"
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: false,
},
{
name: "int macro",
yaml: `
startPort: 10000
macros:
NUM: 42
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: false,
},
{
name: "float macro",
yaml: `
startPort: 10000
macros:
FLOAT: 3.14
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: false,
},
{
name: "bool macro",
yaml: `
startPort: 10000
macros:
BOOL: true
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: false,
},
{
name: "array macro (invalid)",
yaml: `
startPort: 10000
macros:
ARR: [1, 2, 3]
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: true,
},
{
name: "map macro (invalid)",
yaml: `
startPort: 10000
macros:
MAP:
key: value
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.yaml))
if tt.shouldErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestConfig_APIKeys_Invalid(t *testing.T) {
tests := []struct {
name string
content string
expectedErr string
}{
{
name: "empty string",
content: `apiKeys: [""]`,
expectedErr: "empty api key found in apiKeys",
},
{
name: "blank spaces only",
content: `apiKeys: [" "]`,
expectedErr: "api key cannot contain spaces: ` `",
},
{
name: "contains leading space",
content: `apiKeys: [" key123"]`,
expectedErr: "api key cannot contain spaces: ` key123`",
},
{
name: "contains trailing space",
content: `apiKeys: ["key123 "]`,
expectedErr: "api key cannot contain spaces: `key123 `",
},
{
name: "contains middle space",
content: `apiKeys: ["key 123"]`,
expectedErr: "api key cannot contain spaces: `key 123`",
},
{
name: "empty in list with valid keys",
content: `apiKeys: ["valid-key", "", "another-key"]`,
expectedErr: "empty api key found in apiKeys",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
if assert.Error(t, err) {
assert.Equal(t, tt.expectedErr, err.Error())
}
})
}
}
@@ -1,6 +1,6 @@
//go:build windows
package proxy
package config
import (
"os"
@@ -55,6 +55,7 @@ models:
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)
assert.Equal(t, "", config.LogTimeFormat)
// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
@@ -152,44 +153,53 @@ groups:
t.Fatalf("Failed to load config: %v", err)
}
modelLoadingState := false
expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
"svr-path": "path/to/server",
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
StartPort: 5800,
Macros: MacroList{
{"svr-path", "path/to/server"},
},
SendLoadingState: false,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState,
},
"model2": {
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
},
},
HealthCheckTimeout: 15,
+123
View File
@@ -0,0 +1,123 @@
package config
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// Test macro-in-macro basic substitution
func TestConfig_MacroInMacroBasic(t *testing.T) {
content := `
startPort: 10000
macros:
"A": "value-A"
"B": "prefix-${A}-suffix"
models:
test:
cmd: echo ${B}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
}
// Test LIFO substitution order with 3+ macro levels
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
content := `
startPort: 10000
macros:
"base": "/models"
"path": "${base}/llama"
"full": "${path}/model.gguf"
models:
test:
cmd: load ${full}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
}
// Test MODEL_ID in global macro used by model
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
content := `
startPort: 10000
macros:
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
models:
my-model:
cmd: ${podman-llama} -m model.gguf
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
}
// Test model macro overrides global macro in substitution
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
content := `
startPort: 10000
macros:
"tag": "global"
"msg": "value-${tag}"
models:
test:
macros:
"tag": "model-level"
cmd: echo ${msg}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
}
// Test self-reference detection error
func TestConfig_SelfReferenceDetection(t *testing.T) {
content := `
startPort: 10000
macros:
"recursive": "value-${recursive}"
models:
test:
cmd: echo ${recursive}
proxy: http://localhost:8080
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "recursive")
assert.Contains(t, err.Error(), "self-reference")
}
// Test undefined macro reference error
func TestConfig_UndefinedMacroReference(t *testing.T) {
content := `
startPort: 10000
macros:
"A": "value-${UNDEFINED}"
models:
test:
cmd: echo ${A}
proxy: http://localhost:8080
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "UNDEFINED")
}
+128
View File
@@ -0,0 +1,128 @@
package config
import (
"errors"
"runtime"
"slices"
"strings"
)
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
// Macros: see #264
// Model level macros take precedence over the global macros
Macros MacroList `yaml:"macros"`
// Metadata: see #264
// Arbitrary metadata that can be exposed through the API
Metadata map[string]any `yaml:"metadata"`
// override global setting
SendLoadingState *bool `yaml:"sendLoadingState"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
}
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelConfig(defaults)
return nil
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
type ModelFilters struct {
StripParams string `yaml:"stripParams"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
if err := unmarshal(&defaults); err != nil {
return err
}
// Try to unmarshal with the old field name for backwards compatibility
if defaults.StripParams == "" {
var legacy struct {
StripParams string `yaml:"strip_params"`
}
if legacyErr := unmarshal(&legacy); legacyErr != nil {
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
}
defaults.StripParams = legacy.StripParams
}
*m = ModelFilters(defaults)
return nil
}
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
}
+74
View File
@@ -0,0 +1,74 @@
package config
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
--arg1 value1 \
--arg2 value2`,
}
args, err := config.SanitizedCommand()
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestConfig_ModelFilters(t *testing.T) {
content := `
macros:
default_strip: "temperature, top_p"
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
# macros inserted and list is cleaned of duplicates and empty strings
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
# check for strip_params (legacy field name) compatibility
legacy:
cmd: path/to/cmd --port ${PORT}
filters:
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
for modelId, modelConfig := range config.Models {
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
sanitized, err := modelConfig.Filters.SanitizedStripParams()
if assert.NoError(t, err) {
// model has been removed
// empty strings have been removed
// duplicates have been removed
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
}
})
}
}
func TestConfig_ModelSendLoadingState(t *testing.T) {
content := `
sendLoadingState: true
models:
model1:
cmd: path/to/cmd --port ${PORT}
sendLoadingState: false
model2:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.True(t, config.SendLoadingState)
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
assert.False(t, *config.Models["model1"].SendLoadingState)
}
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
assert.True(t, *config.Models["model2"].SendLoadingState)
}
}
+47
View File
@@ -0,0 +1,47 @@
package config
import (
"fmt"
"net/url"
)
type PeerDictionaryConfig map[string]PeerConfig
type PeerConfig struct {
Proxy string `yaml:"proxy"`
ProxyURL *url.URL `yaml:"-"`
ApiKey string `yaml:"apiKey"`
Models []string `yaml:"models"`
}
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawPeerConfig PeerConfig
defaults := rawPeerConfig{
Proxy: "",
ApiKey: "",
Models: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
// Validate proxy is not empty
if defaults.Proxy == "" {
return fmt.Errorf("proxy is required")
}
// Validate proxy is a valid URL and store the parsed value
parsedURL, err := url.Parse(defaults.Proxy)
if err != nil {
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
}
defaults.ProxyURL = parsedURL
// Validate models is not empty
if len(defaults.Models) == 0 {
return fmt.Errorf("peer models can not be empty")
}
*c = PeerConfig(defaults)
return nil
}
+139
View File
@@ -0,0 +1,139 @@
package config
import (
"testing"
"gopkg.in/yaml.v3"
)
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
tests := []struct {
name string
yaml string
wantErr string
}{
{
name: "valid config",
yaml: `
proxy: http://192.168.1.23
models:
- model_a
- model_b
`,
wantErr: "",
},
{
name: "valid config with apiKey",
yaml: `
proxy: https://openrouter.ai/api
apiKey: sk-test-key
models:
- meta-llama/llama-3.1-8b-instruct
`,
wantErr: "",
},
{
name: "missing proxy",
yaml: `
models:
- model_a
`,
wantErr: "proxy is required",
},
{
name: "empty proxy",
yaml: `
proxy: ""
models:
- model_a
`,
wantErr: "proxy is required",
},
{
name: "invalid proxy URL",
yaml: `
proxy: "://invalid"
models:
- model_a
`,
wantErr: "invalid peer proxy URL",
},
{
name: "missing models",
yaml: `
proxy: http://localhost:8080
`,
wantErr: "peer models can not be empty",
},
{
name: "empty models",
yaml: `
proxy: http://localhost:8080
models: []
`,
wantErr: "peer models can not be empty",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var config PeerConfig
err := yaml.Unmarshal([]byte(tt.yaml), &config)
if tt.wantErr == "" {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
} else {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.wantErr)
} else if !contains(err.Error(), tt.wantErr) {
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
}
}
})
}
}
func TestPeerConfig_ProxyURL(t *testing.T) {
yamlData := `
proxy: http://192.168.1.23:8080/api
apiKey: sk-test
models:
- model_a
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if config.ProxyURL == nil {
t.Fatal("ProxyURL should not be nil")
}
if config.ProxyURL.Host != "192.168.1.23:8080" {
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
}
if config.ProxyURL.Scheme != "http" {
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
}
if config.ProxyURL.Path != "/api" {
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && searchSubstring(s, substr)
}
func searchSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
+4 -3
View File
@@ -9,6 +9,7 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy/config"
"gopkg.in/yaml.v3"
)
@@ -65,18 +66,18 @@ func getTestPort() int {
return port
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
}
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
// Create a YAML string with just the values we want to set
yamlStr := fmt.Sprintf(`
cmd: '%s --port %d --silent --respond %s'
proxy: "http://127.0.0.1:%d"
`, simpleResponderPath, port, expectedMessage, port)
var cfg ModelConfig
var cfg config.ModelConfig
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
}
+121 -21
View File
@@ -1,16 +1,95 @@
package proxy
import (
"container/ring"
"context"
"fmt"
"io"
"os"
"sync"
"time"
"github.com/mostlygeek/llama-swap/event"
)
// circularBuffer is a fixed-size circular byte buffer that overwrites
// oldest data when full. It provides O(1) writes and O(n) reads.
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
}
func newCircularBuffer(capacity int) *circularBuffer {
return &circularBuffer{
data: make([]byte, capacity),
head: 0,
size: 0,
}
}
// Write appends bytes to the buffer, overwriting oldest data when full.
// Data is copied into the internal buffer (not stored by reference).
func (cb *circularBuffer) Write(p []byte) {
if len(p) == 0 {
return
}
cap := len(cb.data)
// If input is larger than capacity, only keep the last cap bytes
if len(p) >= cap {
copy(cb.data, p[len(p)-cap:])
cb.head = 0
cb.size = cap
return
}
// Calculate how much space is available from head to end of buffer
firstPart := cap - cb.head
if firstPart >= len(p) {
// All data fits without wrapping
copy(cb.data[cb.head:], p)
cb.head = (cb.head + len(p)) % cap
} else {
// Data wraps around
copy(cb.data[cb.head:], p[:firstPart])
copy(cb.data[:len(p)-firstPart], p[firstPart:])
cb.head = len(p) - firstPart
}
// Update size
cb.size += len(p)
if cb.size > cap {
cb.size = cap
}
}
// GetHistory returns all buffered data in correct order (oldest to newest).
// Returns a new slice (copy), not a view into internal buffer.
func (cb *circularBuffer) GetHistory() []byte {
if cb.size == 0 {
return nil
}
result := make([]byte, cb.size)
cap := len(cb.data)
// Calculate start position (oldest data)
start := (cb.head - cb.size + cap) % cap
if start+cb.size <= cap {
// Data is contiguous, single copy
copy(result, cb.data[start:start+cb.size])
} else {
// Data wraps around, two copies
firstPart := cap - start
copy(result[:firstPart], cb.data[start:])
copy(result[firstPart:], cb.data[:cb.size-firstPart])
}
return result
}
type LogLevel int
const (
@@ -18,12 +97,14 @@ const (
LevelInfo
LevelWarn
LevelError
LogBufferSize = 100 * 1024
)
type LogMonitor struct {
eventbus *event.Dispatcher
mu sync.RWMutex
buffer *ring.Ring
buffer *circularBuffer
bufferMu sync.RWMutex
// typically this can be os.Stdout
@@ -32,6 +113,9 @@ type LogMonitor struct {
// logging levels
level LogLevel
prefix string
// timestamps
timeFormat string
}
func NewLogMonitor() *LogMonitor {
@@ -40,11 +124,12 @@ func NewLogMonitor() *LogMonitor {
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{
eventbus: event.NewDispatcherConfig(1000),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
eventbus: event.NewDispatcherConfig(1000),
buffer: nil, // lazy initialized on first Write
stdout: stdout,
level: LevelInfo,
prefix: "",
timeFormat: "",
}
}
@@ -59,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
}
w.bufferMu.Lock()
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.buffer.Value = bufferCopy
w.buffer = w.buffer.Next()
if w.buffer == nil {
w.buffer = newCircularBuffer(LogBufferSize)
}
w.buffer.Write(p)
w.bufferMu.Unlock()
// Make a copy for broadcast to preserve immutability
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.broadcast(bufferCopy)
return n, nil
}
@@ -72,16 +160,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
func (w *LogMonitor) GetHistory() []byte {
w.bufferMu.RLock()
defer w.bufferMu.RUnlock()
if w.buffer == nil {
return nil
}
return w.buffer.GetHistory()
}
var history []byte
w.buffer.Do(func(p any) {
if p != nil {
if content, ok := p.([]byte); ok {
history = append(history, content...)
}
}
})
return history
// Clear releases the buffer memory, making it eligible for GC.
// The buffer will be lazily re-allocated on the next Write.
func (w *LogMonitor) Clear() {
w.bufferMu.Lock()
w.buffer = nil
w.bufferMu.Unlock()
}
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
@@ -106,12 +196,22 @@ func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.level = level
}
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
w.mu.Lock()
defer w.mu.Unlock()
w.timeFormat = timeFormat
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
timestamp := ""
if w.timeFormat != "" {
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
}
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
+230
View File
@@ -3,8 +3,10 @@ package proxy
import (
"bytes"
"io"
"strings"
"sync"
"testing"
"time"
)
func TestLogMonitor(t *testing.T) {
@@ -84,3 +86,231 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
t.Errorf("Expected history to be %q, got %q", expected, history)
}
}
func TestWrite_LogTimeFormat(t *testing.T) {
// Create a new LogMonitor instance
lm := NewLogMonitorWriter(io.Discard)
// Enable timestamps
lm.timeFormat = time.RFC3339
// Write the message to the LogMonitor
lm.Info("Hello, World!")
// Get the history from the LogMonitor
history := lm.GetHistory()
timestamp := ""
fields := strings.Fields(string(history))
if len(fields) > 0 {
timestamp = fields[0]
} else {
t.Fatalf("Cannot extract string from history")
}
_, err := time.Parse(time.RFC3339, timestamp)
if err != nil {
t.Fatalf("Cannot find timestamp: %v", err)
}
}
func TestCircularBuffer_WrapAround(t *testing.T) {
// Create a small buffer to test wrap-around
cb := newCircularBuffer(10)
// Write "hello" (5 bytes)
cb.Write([]byte("hello"))
if got := string(cb.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Write "world" (5 bytes) - buffer now full
cb.Write([]byte("world"))
if got := string(cb.GetHistory()); got != "helloworld" {
t.Errorf("Expected 'helloworld', got %q", got)
}
// Write "12345" (5 bytes) - should overwrite "hello"
cb.Write([]byte("12345"))
if got := string(cb.GetHistory()); got != "world12345" {
t.Errorf("Expected 'world12345', got %q", got)
}
// Write data larger than buffer capacity
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
if got := string(cb.GetHistory()); got != "ghijklmnop" {
t.Errorf("Expected 'ghijklmnop', got %q", got)
}
}
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
// Test empty buffer
cb := newCircularBuffer(10)
if got := cb.GetHistory(); got != nil {
t.Errorf("Expected nil for empty buffer, got %q", got)
}
// Test exact capacity
cb.Write([]byte("1234567890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
// Test write exactly at capacity boundary
cb = newCircularBuffer(10)
cb.Write([]byte("12345"))
cb.Write([]byte("67890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
}
func TestLogMonitor_LazyInit(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Buffer should be nil before any writes
if lm.buffer != nil {
t.Error("Expected buffer to be nil before first write")
}
// GetHistory should return nil when buffer is nil
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history before first write, got %q", got)
}
// Write should lazily initialize the buffer
lm.Write([]byte("test"))
if lm.buffer == nil {
t.Error("Expected buffer to be initialized after write")
}
if got := string(lm.GetHistory()); got != "test" {
t.Errorf("Expected 'test', got %q", got)
}
}
func TestLogMonitor_Clear(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Write some data
lm.Write([]byte("hello"))
if got := string(lm.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Clear should release the buffer
lm.Clear()
if lm.buffer != nil {
t.Error("Expected buffer to be nil after Clear")
}
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history after Clear, got %q", got)
}
}
func TestLogMonitor_ClearAndReuse(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Write, clear, then write again
lm.Write([]byte("first"))
lm.Clear()
lm.Write([]byte("second"))
if got := string(lm.GetHistory()); got != "second" {
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
}
}
func BenchmarkLogMonitorWrite(b *testing.B) {
// Test data of varying sizes
smallMsg := []byte("small message\n")
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
b.Run("SmallWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(smallMsg)
}
})
b.Run("MediumWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
}
})
b.Run("LargeWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(largeMsg)
}
})
b.Run("WithSubscribers", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Add some subscribers
for i := 0; i < 5; i++ {
lm.OnLogData(func(data []byte) {})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
}
})
b.Run("GetHistory", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Pre-populate with data
for i := 0; i < 1000; i++ {
lm.Write(mediumMsg)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.GetHistory()
}
})
}
/*
Benchmark Results - MBP M1 Pro
Before (ring.Ring):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 43 ns | 40 B | 2 |
| MediumWrite (241B) | 76 ns | 264 B | 2 |
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
After (circularBuffer 10KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 67 ns | 240 B | 1 |
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
After (circularBuffer 100KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|-----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 66 ns | 240 B | 1 |
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
Summary:
- GetHistory: 139x faster (10KB), 18x faster (100KB)
- Allocations: reduced from 2 to 1 across all operations
- Small/medium writes: ~1.1-1.6x faster
*/
-179
View File
@@ -1,179 +0,0 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
// isStreaming bool
startTime time.Time
}
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
c.Abort()
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
c.Abort()
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
c.Abort()
return
}
writer := &MetricsResponseWriter{
ResponseWriter: c.Writer,
metricsRecorder: &MetricsRecorder{
metricsMonitor: pm.metricsMonitor,
realModelName: realModelName,
startTime: time.Now(),
},
}
c.Writer = writer
c.Next()
// check for streaming response
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
writer.metricsRecorder.processStreamingResponse(writer.body)
} else {
writer.metricsRecorder.processNonStreamingResponse(writer.body)
}
}
}
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
usage := jsonData.Get("usage")
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return false
}
// default values
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(rec.startTime).Milliseconds())
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
}
rec.metricsMonitor.addMetrics(TokenMetrics{
Timestamp: time.Now(),
Model: rec.realModelName,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
})
return true
}
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
// Iterate **backwards** through the lines looking for the data payload with
// usage data
lines := bytes.Split(body, []byte("\n"))
for i := len(lines) - 1; i >= 0; i-- {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
return // short circuit if a metric was recorded
}
}
}
}
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
if len(body) == 0 {
return
}
// Parse JSON to extract usage information
if gjson.ValidBytes(body) {
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
}
}
// MetricsResponseWriter captures the entire response for non-streaming
type MetricsResponseWriter struct {
gin.ResponseWriter
body []byte
metricsRecorder *MetricsRecorder
}
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
if err != nil {
return n, err
}
w.body = append(w.body, b...)
return n, nil
}
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *MetricsResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
+291 -15
View File
@@ -1,11 +1,20 @@
package proxy
import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/tidwall/gjson"
)
// TokenMetrics represents parsed token statistics from llama-server logs
@@ -13,6 +22,7 @@ type TokenMetrics struct {
ID int `json:"id"`
Timestamp time.Time `json:"timestamp"`
Model string `json:"model"`
CachedTokens int `json:"cache_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
PromptPerSecond float64 `json:"prompt_per_second"`
@@ -29,21 +39,18 @@ func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
}
// MetricsMonitor parses llama-server output for token statistics
type MetricsMonitor struct {
// metricsMonitor parses llama-server output for token statistics
type metricsMonitor struct {
mu sync.RWMutex
metrics []TokenMetrics
maxMetrics int
nextID int
logger *LogMonitor
}
func NewMetricsMonitor(config *Config) *MetricsMonitor {
maxMetrics := config.MetricsMaxInMemory
if maxMetrics <= 0 {
maxMetrics = 1000 // Default fallback
}
mp := &MetricsMonitor{
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
mp := &metricsMonitor{
logger: logger,
maxMetrics: maxMetrics,
}
@@ -51,7 +58,7 @@ func NewMetricsMonitor(config *Config) *MetricsMonitor {
}
// addMetrics adds a new metric to the collection and publishes an event
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
mp.mu.Lock()
defer mp.mu.Unlock()
@@ -61,12 +68,11 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
if len(mp.metrics) > mp.maxMetrics {
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
}
event.Emit(TokenMetricsEvent{Metrics: metric})
}
// GetMetrics returns a copy of the current metrics
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
// getMetrics returns a copy of the current metrics
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
mp.mu.RLock()
defer mp.mu.RUnlock()
@@ -75,9 +81,279 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
return result
}
// GetMetricsJSON returns metrics as JSON
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
// getMetricsJSON returns metrics as JSON
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
}
// wrapHandler wraps the proxy handler to extract token metrics
// if wrapHandler returns an error it is safe to assume that no
// data was sent to the client
func (mp *metricsMonitor) wrapHandler(
modelID string,
writer gin.ResponseWriter,
request *http.Request,
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error {
recorder := newBodyCopier(writer)
// Filter Accept-Encoding to only include encodings we can decompress for metrics
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
}
if err := next(modelID, recorder, request); err != nil {
return err
}
// after this point we have to assume that data was sent to the client
// and we can only log errors but not send them to clients
if recorder.Status() != http.StatusOK {
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
return nil
}
// Initialize default metrics - these will always be recorded
tm := TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
}
body := recorder.body.Bytes()
if len(body) == 0 {
mp.logger.Warn("metrics: empty body, recording minimal metrics")
mp.addMetrics(tm)
return nil
}
// Decompress if needed
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
var err error
body, err = decompressBody(body, encoding)
if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
mp.addMetrics(tm)
return nil
}
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsed
}
} else {
if gjson.ValidBytes(body) {
parsed := gjson.ParseBytes(body)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
if usage.Exists() || timings.Exists() {
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsedMetrics
}
}
} else {
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
}
}
mp.addMetrics(tm)
return nil
}
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
// Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split.
// Start from the end of the body and scan backwards for newlines
pos := len(body)
for pos > 0 {
// Find the previous newline (or start of body)
lineStart := bytes.LastIndexByte(body[:pos], '\n')
if lineStart == -1 {
lineStart = 0
} else {
lineStart++ // Move past the newline
}
line := bytes.TrimSpace(body[lineStart:pos])
pos = lineStart - 1 // Move position before the newline for next iteration
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
parsed := gjson.ParseBytes(data)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
if usage.Exists() || timings.Exists() {
return parseMetrics(modelID, start, usage, timings)
}
}
}
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
}
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(start).Milliseconds())
if usage.Exists() {
if pt := usage.Get("prompt_tokens"); pt.Exists() {
// v1/chat/completions
inputTokens = int(pt.Int())
} else if it := usage.Get("input_tokens"); it.Exists() {
// v1/messages
inputTokens = int(it.Int())
}
if ct := usage.Get("completion_tokens"); ct.Exists() {
// v1/chat/completions
outputTokens = int(ct.Int())
} else if ot := usage.Get("output_tokens"); ot.Exists() {
outputTokens = int(ot.Int())
}
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
cachedTokens = int(ct.Int())
}
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(timings.Get("prompt_n").Int())
outputTokens = int(timings.Get("predicted_n").Int())
promptPerSecond = timings.Get("prompt_per_second").Float()
tokensPerSecond = timings.Get("predicted_per_second").Float()
durationMs = int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
return TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
}, nil
}
// decompressBody decompresses the body based on Content-Encoding header
func decompressBody(body []byte, encoding string) ([]byte, error) {
switch strings.ToLower(strings.TrimSpace(encoding)) {
case "gzip":
reader, err := gzip.NewReader(bytes.NewReader(body))
if err != nil {
return nil, err
}
defer reader.Close()
return io.ReadAll(reader)
case "deflate":
reader := flate.NewReader(bytes.NewReader(body))
defer reader.Close()
return io.ReadAll(reader)
default:
return body, nil // Return as-is for unknown/no encoding
}
}
// responseBodyCopier records the response body and writes to the original response writer
// while also capturing it in a buffer for later processing
type responseBodyCopier struct {
gin.ResponseWriter
body *bytes.Buffer
tee io.Writer
start time.Time
}
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
bodyBuffer := &bytes.Buffer{}
return &responseBodyCopier{
ResponseWriter: w,
body: bodyBuffer,
tee: io.MultiWriter(w, bodyBuffer),
}
}
func (w *responseBodyCopier) Write(b []byte) (int, error) {
if w.start.IsZero() {
w.start = time.Now()
}
// Single write operation that writes to both the response and buffer
return w.tee.Write(b)
}
func (w *responseBodyCopier) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *responseBodyCopier) Header() http.Header {
return w.ResponseWriter.Header()
}
func (w *responseBodyCopier) StartTime() time.Time {
return w.start
}
// filterAcceptEncoding filters the Accept-Encoding header to only include
// encodings we can decompress (gzip, deflate). This respects the client's
// preferences while ensuring we can parse response bodies for metrics.
func filterAcceptEncoding(acceptEncoding string) string {
if acceptEncoding == "" {
return ""
}
supported := map[string]bool{"gzip": true, "deflate": true}
var filtered []string
for _, part := range strings.Split(acceptEncoding, ",") {
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
encoding := strings.TrimSpace(strings.Split(part, ";")[0])
if supported[strings.ToLower(encoding)] {
filtered = append(filtered, strings.TrimSpace(part))
}
}
return strings.Join(filtered, ", ")
}
+834
View File
@@ -0,0 +1,834 @@
package proxy
import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/stretchr/testify/assert"
)
func TestMetricsMonitor_AddMetrics(t *testing.T) {
t.Run("adds metrics and assigns ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
}
mm.addMetrics(metric)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 0, metrics[0].ID)
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("increments ID for each metric", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{Model: "model"})
}
metrics := mm.getMetrics()
assert.Equal(t, 5, len(metrics))
for i := 0; i < 5; i++ {
assert.Equal(t, i, metrics[i].ID)
}
})
t.Run("respects max metrics limit", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 3)
// Add 5 metrics
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{
Model: "model",
InputTokens: i,
})
}
metrics := mm.getMetrics()
assert.Equal(t, 3, len(metrics))
// Should keep the last 3 metrics (IDs 2, 3, 4)
assert.Equal(t, 2, metrics[0].ID)
assert.Equal(t, 3, metrics[1].ID)
assert.Equal(t, 4, metrics[2].ID)
})
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
receivedEvent := make(chan TokenMetricsEvent, 1)
cancel := event.On(func(e TokenMetricsEvent) {
receivedEvent <- e
})
defer cancel()
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
}
mm.addMetrics(metric)
select {
case evt := <-receivedEvent:
assert.Equal(t, 0, evt.Metrics.ID)
assert.Equal(t, "test-model", evt.Metrics.Model)
assert.Equal(t, 100, evt.Metrics.InputTokens)
assert.Equal(t, 50, evt.Metrics.OutputTokens)
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for event")
}
})
}
func TestMetricsMonitor_GetMetrics(t *testing.T) {
t.Run("returns empty slice when no metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
metrics := mm.getMetrics()
assert.NotNil(t, metrics)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns copy of metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm.addMetrics(TokenMetrics{Model: "model1"})
mm.addMetrics(TokenMetrics{Model: "model2"})
metrics1 := mm.getMetrics()
metrics2 := mm.getMetrics()
// Verify we got copies
assert.Equal(t, 2, len(metrics1))
assert.Equal(t, 2, len(metrics2))
// Modify the returned slice shouldn't affect the original
metrics1[0].Model = "modified"
metrics3 := mm.getMetrics()
assert.Equal(t, "model1", metrics3[0].Model)
})
}
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
assert.NotNil(t, jsonData)
var metrics []TokenMetrics
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns valid JSON with metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm.addMetrics(TokenMetrics{
Model: "model1",
InputTokens: 100,
OutputTokens: 50,
TokensPerSecond: 25.5,
})
mm.addMetrics(TokenMetrics{
Model: "model2",
InputTokens: 200,
OutputTokens: 100,
TokensPerSecond: 30.0,
})
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
var metrics []TokenMetrics
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 2, len(metrics))
assert.Equal(t, "model1", metrics[0].Model)
assert.Equal(t, "model2", metrics[1].Model)
})
}
func TestMetricsMonitor_WrapHandler(t *testing.T) {
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("successful request with timings data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0,
"cache_n": 20
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 20, metrics[0].CachedTokens)
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
})
t.Run("streaming request with SSE format", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Note: SSE format requires proper line breaks - each data line followed by blank line
responseBody := `data: {"choices":[{"text":"Hello"}]}
data: {"choices":[{"text":" World"}]}
data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}}
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
// When timings data is present, it takes precedence
assert.Equal(t, 10, metrics[0].InputTokens)
assert.Equal(t, 20, metrics[0].OutputTokens)
})
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error"))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("empty response body records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusOK)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("not valid json"))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("next handler error is propagated", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
expectedErr := assert.AnError
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
return expectedErr
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.Equal(t, expectedErr, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("response without usage or timings records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"result": "ok"}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
}
func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
t.Run("captures response body", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
testData := []byte("test response body")
n, err := copier.Write(testData)
assert.NoError(t, err)
assert.Equal(t, len(testData), n)
assert.Equal(t, testData, copier.body.Bytes())
assert.Equal(t, string(testData), rec.Body.String())
})
t.Run("sets start time on first write", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
assert.True(t, copier.StartTime().IsZero())
copier.Write([]byte("test"))
assert.False(t, copier.StartTime().IsZero())
})
t.Run("preserves headers", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
copier.Header().Set("X-Test", "value")
assert.Equal(t, "value", rec.Header().Get("X-Test"))
})
t.Run("preserves status code", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
copier.WriteHeader(http.StatusCreated)
// Gin's ResponseWriter tracks status internally
assert.Equal(t, http.StatusCreated, copier.Status())
})
}
func TestMetricsMonitor_Concurrent(t *testing.T) {
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 1000)
var wg sync.WaitGroup
numGoroutines := 10
metricsPerGoroutine := 100
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < metricsPerGoroutine; j++ {
mm.addMetrics(TokenMetrics{
Model: "test-model",
InputTokens: id*1000 + j,
OutputTokens: j,
})
}
}(i)
}
wg.Wait()
metrics := mm.getMetrics()
assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics))
})
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 100)
done := make(chan bool)
// Writer goroutine
go func() {
for i := 0; i < 50; i++ {
mm.addMetrics(TokenMetrics{Model: "test-model"})
time.Sleep(1 * time.Millisecond)
}
done <- true
}()
// Multiple reader goroutines
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 20; j++ {
_ = mm.getMetrics()
_, _ = mm.getMetricsJSON()
time.Sleep(2 * time.Millisecond)
}
}()
}
<-done
wg.Wait()
// Final check
metrics := mm.getMetrics()
assert.Equal(t, 50, len(metrics))
})
}
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
t.Run("prefers timings over usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Timings should take precedence over usage
responseBody := `{
"usage": {
"prompt_tokens": 50,
"completion_tokens": 25
},
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
// Should use timings values, not usage values
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles missing cache_n in timings", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
})
}
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Metrics should be found in the last data line before [DONE]
responseBody := `data: {"choices":[{"text":"First"}]}
data: {"choices":[{"text":"Second"}]}
data: {"usage":{"prompt_tokens":100,"completion_tokens":50}}
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `data: not json
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := ``
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
}
// Benchmark tests
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
mm := newMetricsMonitor(testLogger, 1000)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
}
}
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
// Test performance with a smaller buffer where wrapping occurs more frequently
mm := newMetricsMonitor(testLogger, 100)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
}
}
func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
t.Run("gzip encoded response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
// Compress with gzip
var buf bytes.Buffer
gzWriter := gzip.NewWriter(&buf)
gzWriter.Write([]byte(responseBody))
gzWriter.Close()
compressedBody := buf.Bytes()
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "gzip")
w.WriteHeader(http.StatusOK)
w.Write(compressedBody)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("deflate encoded response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}`
// Compress with deflate
var buf bytes.Buffer
flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression)
flateWriter.Write([]byte(responseBody))
flateWriter.Close()
compressedBody := buf.Bytes()
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "deflate")
w.WriteHeader(http.StatusOK)
w.Write(compressedBody)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 200, metrics[0].InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens)
})
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Invalid compressed data
invalidData := []byte("this is not gzip data")
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "gzip")
w.WriteHeader(http.StatusOK)
w.Write(invalidData)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Should not return error, just log warning
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "unknown-encoding")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 300, metrics[0].InputTokens)
assert.Equal(t, 100, metrics[0].OutputTokens)
})
}
+127
View File
@@ -0,0 +1,127 @@
package proxy
import (
"fmt"
"net"
"net/http"
"net/http/httputil"
"runtime"
"sort"
"strings"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
)
type peerProxyMember struct {
peerID string
reverseProxy *httputil.ReverseProxy
apiKey string
}
type PeerProxy struct {
peers config.PeerDictionaryConfig
proxyMap map[string]*peerProxyMember
}
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
proxyMap := make(map[string]*peerProxyMember)
// Sort peer IDs for consistent iteration order
peerIDs := make([]string, 0, len(peers))
for peerID := range peers {
peerIDs = append(peerIDs, peerID)
}
sort.Strings(peerIDs)
// Create a shared transport with reasonable timeouts for peer connections
// these can be tuned with feedback later
peerTransport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second, // Connection timeout
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 60 * time.Second, // Time to wait for response headers
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
}
for _, peerID := range peerIDs {
peer := peers[peerID]
// Create reverse proxy for this peer
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
reverseProxy.Transport = peerTransport
// Wrap Director to set Host header for remote hosts (not localhost)
originalDirector := reverseProxy.Director
reverseProxy.Director = func(req *http.Request) {
originalDirector(req)
// Ensure Host header matches target URL for remote proxying
req.Host = req.URL.Host
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
errMsg := fmt.Sprintf("peer proxy error: %v", err)
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
}
http.Error(w, errMsg, http.StatusBadGateway)
}
pp := &peerProxyMember{
peerID: peerID,
reverseProxy: reverseProxy,
apiKey: peer.ApiKey,
}
// Map each model to this peer's proxy
for _, modelID := range peer.Models {
if _, found := proxyMap[modelID]; found {
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
continue
}
proxyMap[modelID] = pp
}
}
return &PeerProxy{
peers: peers,
proxyMap: proxyMap,
}, nil
}
func (p *PeerProxy) HasPeerModel(modelID string) bool {
_, found := p.proxyMap[modelID]
return found
}
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
return p.peers
}
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
pp, found := p.proxyMap[model_id]
if !found {
return fmt.Errorf("no peer proxy found for model %s", model_id)
}
// Inject API key if configured for this peer
if pp.apiKey != "" {
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
request.Header.Set("x-api-key", pp.apiKey)
}
pp.reverseProxy.ServeHTTP(writer, request)
return nil
}
+268
View File
@@ -0,0 +1,268 @@
package proxy
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
peers := config.PeerDictionaryConfig{}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.NotNil(t, pm)
assert.Empty(t, pm.proxyMap)
}
func TestNewPeerProxy_SinglePeer(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
ApiKey: "test-key",
Models: []string{"model-a", "model-b"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.Len(t, pm.proxyMap, 2)
assert.True(t, pm.HasPeerModel("model-a"))
assert.True(t, pm.HasPeerModel("model-b"))
assert.False(t, pm.HasPeerModel("model-c"))
}
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"model-a", "model-b"},
},
"peer2": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"model-c", "model-d"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.Len(t, pm.proxyMap, 4)
assert.True(t, pm.HasPeerModel("model-a"))
assert.True(t, pm.HasPeerModel("model-b"))
assert.True(t, pm.HasPeerModel("model-c"))
assert.True(t, pm.HasPeerModel("model-d"))
}
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
// should be mapped, and a warning should be logged
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"alpha-peer": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"duplicate-model"},
},
"beta-peer": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"duplicate-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
// Should only have one entry for the duplicate model
assert.Len(t, pm.proxyMap, 1)
assert.True(t, pm.HasPeerModel("duplicate-model"))
}
func TestHasPeerModel(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
Models: []string{"existing-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.True(t, pm.HasPeerModel("existing-model"))
assert.False(t, pm.HasPeerModel("non-existing-model"))
}
func TestProxyRequest_ModelNotFound(t *testing.T) {
peers := config.PeerDictionaryConfig{}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("non-existing-model", w, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
}
func TestProxyRequest_Success(t *testing.T) {
// Create a test server to act as the peer
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("response from peer"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "response from peer", w.Body.String())
}
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
// Create a test server that checks for the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "secret-api-key",
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
}
func TestProxyRequest_NoApiKey(t *testing.T) {
// Create a test server that checks for the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "", // No API key
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Empty(t, receivedAuthHeader)
}
func TestProxyRequest_HostHeaderSet(t *testing.T) {
// Create a test server that checks the Host header
var receivedHost string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHost = r.Host
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
// The Host header should be set to the target URL's host
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
}
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
// Create a test server that returns SSE content type
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
// The X-Accel-Buffering header should be set to "no" for SSE
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
}
+392 -65
View File
@@ -2,19 +2,23 @@ package proxy
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os/exec"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
)
type ProcessState string
@@ -37,11 +41,13 @@ const (
)
type Process struct {
ID string
config ModelConfig
cmd *exec.Cmd
ID string
config config.ModelConfig
cmd *exec.Cmd
reverseProxy *httputil.ReverseProxy
// PR #155 called to cancel the upstream process
cmdMutex sync.RWMutex
cancelUpstream context.CancelFunc
// closed when command exits
@@ -53,12 +59,14 @@ type Process struct {
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time
lastRequestHandledMutex sync.RWMutex
lastRequestHandled time.Time
stateMutex sync.RWMutex
state ProcessState
inFlightRequests sync.WaitGroup
inFlightRequests sync.WaitGroup
inFlightRequestsCount atomic.Int32
// used to block on multiple start() calls
waitStarting sync.WaitGroup
@@ -73,16 +81,35 @@ type Process struct {
failedStartCount int
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
concurrentLimit := 10
if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit
}
// Setup the reverse proxy.
proxyURL, err := url.Parse(config.Proxy)
if err != nil {
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
}
var reverseProxy *httputil.ReverseProxy
if proxyURL != nil {
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
reverseProxy.ModifyResponse = func(resp *http.Response) error {
// prevent nginx from buffering streaming responses (e.g., SSE)
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
}
return &Process{
ID: ID,
config: config,
cmd: nil,
reverseProxy: reverseProxy,
cancelUpstream: nil,
processLogger: processLogger,
proxyLogger: proxyLogger,
@@ -105,6 +132,20 @@ func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
func (p *Process) setLastRequestHandled(t time.Time) {
p.lastRequestHandledMutex.Lock()
defer p.lastRequestHandledMutex.Unlock()
p.lastRequestHandled = t
}
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
func (p *Process) getLastRequestHandled() time.Time {
p.lastRequestHandledMutex.RLock()
defer p.lastRequestHandledMutex.RUnlock()
return p.lastRequestHandled
}
// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
@@ -128,6 +169,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
}
p.state = newState
// Atomically increment waitStarting when entering StateStarting
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
if newState == StateStarting {
p.waitStarting.Add(1)
}
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
return p.state, nil
@@ -156,6 +204,15 @@ func (p *Process) CurrentState() ProcessState {
return p.state
}
// forceState forces the process state to the new state with mutex protection.
// This should only be used in exceptional cases where the normal state transition
// validation via swapState() cannot be used.
func (p *Process) forceState(newState ProcessState) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.state = newState
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
@@ -189,7 +246,7 @@ func (p *Process) start() error {
}
}
p.waitStarting.Add(1)
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
defer p.waitStarting.Done()
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
@@ -199,8 +256,12 @@ func (p *Process) start() error {
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
p.cmd.Cancel = p.cmdStopUpstreamProcess
p.cmd.WaitDelay = p.gracefulStopTimeout
setProcAttributes(p.cmd)
p.cmdMutex.Lock()
p.cancelUpstream = ctxCancelUpstream
p.cmdWaitChan = make(chan struct{})
p.cmdMutex.Unlock()
p.failedStartCount++ // this will be reset to zero when the process has successfully started
@@ -210,7 +271,7 @@ func (p *Process) start() error {
// Set process state to failed
if err != nil {
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
p.state = StateStopped // force it into a stopped state
p.forceState(StateStopped) // force it into a stopped state
return fmt.Errorf(
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
strings.Join(args, " "), err, curState, swapErr,
@@ -283,10 +344,12 @@ func (p *Process) start() error {
return
}
// wait for all inflight requests to complete and ticker
p.inFlightRequests.Wait()
// skip the TTL check if there are inflight requests
if p.inFlightRequestsCount.Load() != 0 {
continue
}
if time.Since(p.lastRequestHandled) > maxDuration {
if time.Since(p.getLastRequestHandled()) > maxDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
p.Stop()
return
@@ -342,7 +405,7 @@ func (p *Process) Shutdown() {
p.stopCommand()
// just force it to this state since there is no recovery from shutdown
p.state = StateShutdown
p.forceState(StateShutdown)
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -351,20 +414,38 @@ func (p *Process) stopCommand() {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
// free the buffer in processLogger so the memory can be recovered
p.processLogger.Clear()
}()
if p.cancelUpstream == nil {
p.cmdMutex.RLock()
cancelUpstream := p.cancelUpstream
cmdWaitChan := p.cmdWaitChan
p.cmdMutex.RUnlock()
if cancelUpstream == nil {
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
return
}
p.cancelUpstream()
<-p.cmdWaitChan
cancelUpstream()
<-cmdWaitChan
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
client := &http.Client{
Timeout: 500 * time.Millisecond,
// wait a short time for a tcp connection to be established
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 500 * time.Millisecond,
}).DialContext,
},
// give a long time to respond to the health check endpoint
// after the connection is established. See issue: 276
Timeout: 5000 * time.Millisecond,
}
req, err := http.NewRequest("GET", healthURL, nil)
@@ -387,6 +468,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if p.reverseProxy == nil {
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
return
}
requestBeginTime := time.Now()
var startDuration time.Duration
@@ -406,68 +493,75 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
}
p.inFlightRequests.Add(1)
p.inFlightRequestsCount.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.setLastRequestHandled(time.Now())
p.inFlightRequestsCount.Add(-1)
p.inFlightRequests.Done()
}()
// for #366
// - extract streaming param from request context, should have been set by proxymanager
var srw *statusResponseWriter
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
// start the process on demand
if p.CurrentState() != StateReady {
// start a goroutine to stream loading status messages into the response writer
// add a sync so the streaming client only runs when the goroutine has exited
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
// PR #417 (no support for anthropic v1/messages yet)
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
srw = newStatusResponseWriter(p, w)
go srw.statusUpdates(swapCtx)
} else {
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
}
beginStartTime := time.Now()
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway)
cancelLoadCtx()
if srw != nil {
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
// before closing the connection. Without this, the connection would close before
// the goroutine can write its cleanup messages, causing incomplete SSE output.
srw.waitForCompletion(100 * time.Millisecond)
} else {
http.Error(w, errstr, http.StatusBadGateway)
}
return
}
startDuration = time.Since(beginStartTime)
}
proxyTo := p.config.Proxy
client := &http.Client{}
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header.Clone()
// should trigger srw to stop sending loading events ...
cancelLoadCtx()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
// faster than io.Copy when streaming
buf := make([]byte, 32*1024)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
return
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
// recover from http.ErrAbortHandler panics that can occur when the client
// disconnects before the response is sent
defer func() {
if r := recover(); r != nil {
if r == http.ErrAbortHandler {
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
} else {
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
}
}
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}()
if srw != nil {
// Wait for the goroutine to finish writing its final messages
const completionTimeout = 1 * time.Second
if !srw.waitForCompletion(completionTimeout) {
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
}
p.reverseProxy.ServeHTTP(srw, r)
} else {
p.reverseProxy.ServeHTTP(w, r)
}
totalTime := time.Since(requestBeginTime)
@@ -503,13 +597,16 @@ func (p *Process) waitForCmd() {
case StateStopping:
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
p.state = StateStopped
p.forceState(StateStopped)
}
default:
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
p.state = StateStopped // force it to be in this state
p.forceState(StateStopped) // force it to be in this state
}
p.cmdMutex.Lock()
close(p.cmdWaitChan)
p.cmdMutex.Unlock()
}
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
@@ -524,7 +621,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
if err != nil {
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
return err
@@ -535,6 +632,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Stdout = p.processLogger
stopCmd.Stderr = p.processLogger
setProcAttributes(stopCmd)
stopCmd.Env = p.cmd.Env
if err := stopCmd.Run(); err != nil {
@@ -550,3 +648,232 @@ func (p *Process) cmdStopUpstreamProcess() error {
return nil
}
// Logger returns the logger for this process.
func (p *Process) Logger() *LogMonitor {
return p.processLogger
}
var loadingRemarks = []string{
"Still faster than your last standup meeting...",
"Reticulating splines...",
"Waking up the hamsters...",
"Teaching the model manners...",
"Convincing the GPU to participate...",
"Loading weights (they're heavy)...",
"Herding electrons...",
"Compiling excuses for the delay...",
"Downloading more RAM...",
"Asking the model nicely to boot up...",
"Bribing CUDA with cookies...",
"Still loading (blame VRAM)...",
"The model is fashionably late...",
"Warming up those tensors...",
"Making the neural net do push-ups...",
"Your patience is appreciated (really)...",
"Almost there (probably)...",
"Loading like it's 1999...",
"The model forgot where it put its keys...",
"Quantum tunneling through layers...",
"Negotiating with the PCIe bus...",
"Defrosting frozen parameters...",
"Teaching attention heads to focus...",
"Running the matrix (slowly)...",
"Untangling transformer blocks...",
"Calibrating the flux capacitor...",
"Spinning up the probability wheels...",
"Waiting for the GPU to wake from its nap...",
"Converting caffeine to compute...",
"Allocating virtual patience...",
"Performing arcane CUDA rituals...",
"The model is stuck in traffic...",
"Inflating embeddings...",
"Summoning computational demons...",
"Pleading with the OOM killer...",
"Calculating the meaning of life (still at 42)...",
"Training the training wheels...",
"Optimizing the optimizer...",
"Bootstrapping the bootstrapper...",
"Loading loading screen...",
"Processing processing logs...",
"Buffering buffer overflow jokes...",
"The model hit snooze...",
"Debugging the debugger...",
"Compiling the compiler...",
"Parsing the parser (meta)...",
"Tokenizing tokens...",
"Encoding the encoder...",
"Hashing hash browns...",
"Forking spoons (not forks)...",
"The model is contemplating existence...",
"Transcending dimensional barriers...",
"Invoking elder tensor gods...",
"Unfurling probability clouds...",
"Synchronizing parallel universes...",
"The GPU is having second thoughts...",
"Recalibrating reality matrices...",
"Time is an illusion, loading doubly so...",
"Convincing bits to flip themselves...",
"The model is reading its own documentation...",
}
type statusResponseWriter struct {
hasWritten bool
writer http.ResponseWriter
process *Process
wg sync.WaitGroup // Track goroutine completion
start time.Time
}
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
s := &statusResponseWriter{
writer: w,
process: p,
start: time.Now(),
}
s.Header().Set("Content-Type", "text/event-stream") // SSE
s.Header().Set("Cache-Control", "no-cache") // no-cache
s.Header().Set("Connection", "keep-alive") // keep-alive
s.WriteHeader(http.StatusOK) // send status code 200
s.sendLine("━━━━━")
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
return s
}
// statusUpdates sends status updates to the client while the model is loading
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
s.wg.Add(1)
defer s.wg.Done()
// Recover from panics caused by client disconnection
// Note: recover() only works within the same goroutine, so we need it here
defer func() {
if r := recover(); r != nil {
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
}
}()
defer func() {
duration := time.Since(s.start)
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
s.sendLine("━━━━━")
s.sendLine(" ")
}()
// Create a shuffled copy of loadingRemarks
remarks := make([]string, len(loadingRemarks))
copy(remarks, loadingRemarks)
rand.Shuffle(len(remarks), func(i, j int) {
remarks[i], remarks[j] = remarks[j], remarks[i]
})
ri := 0
// Pick a random duration to send a remark
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
lastRemarkTime := time.Now()
ticker := time.NewTicker(time.Second)
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if s.process.CurrentState() == StateReady {
return
}
// Check if it's time for a snarky remark
if time.Since(lastRemarkTime) >= nextRemarkIn {
remark := remarks[ri%len(remarks)]
ri++
s.sendLine(fmt.Sprintf("\n%s", remark))
lastRemarkTime = time.Now()
// Pick a new random duration for the next remark
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else {
s.sendData(".")
}
}
}
}
// waitForCompletion waits for the statusUpdates goroutine to finish
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
return true
case <-time.After(timeout):
return false
}
}
func (s *statusResponseWriter) sendLine(line string) {
s.sendData(line + "\n")
}
func (s *statusResponseWriter) sendData(data string) {
// Create the proper SSE JSON structure
type Delta struct {
ReasoningContent string `json:"reasoning_content"`
}
type Choice struct {
Delta Delta `json:"delta"`
}
type SSEMessage struct {
Choices []Choice `json:"choices"`
}
msg := SSEMessage{
Choices: []Choice{
{
Delta: Delta{
ReasoningContent: data,
},
},
},
}
jsonData, err := json.Marshal(msg)
if err != nil {
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
return
}
// Write SSE formatted data, panic if not able to write
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
if err != nil {
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
}
s.Flush()
}
func (s *statusResponseWriter) Header() http.Header {
return s.writer.Header()
}
func (s *statusResponseWriter) Write(data []byte) (int, error) {
return s.writer.Write(data)
}
func (s *statusResponseWriter) WriteHeader(statusCode int) {
if s.hasWritten {
return
}
s.hasWritten = true
s.writer.WriteHeader(statusCode)
s.Flush()
}
func (s *statusResponseWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
+86 -12
View File
@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
)
@@ -90,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
config := ModelConfig{
config := config.ModelConfig{
Cmd: "nonexistent-command",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
@@ -325,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
config := config.ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
@@ -402,7 +403,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
binaryPath := getSimpleResponderPath()
port := getTestPort()
config := ModelConfig{
conf := config.ModelConfig{
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
// to force the process to exit
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
@@ -410,7 +411,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
CheckEndpoint: "/health",
}
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
defer process.Stop()
// reduce to make testing go faster
@@ -435,7 +436,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
if runtime.GOOS == "windows" {
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
} else {
assert.Contains(t, w.Body.String(), "unexpected EOF")
// Upstream may be killed mid-response.
// Assert an incomplete or partial response.
assert.NotEqual(t, "12345", w.Body.String())
}
close(waitChan)
@@ -450,15 +453,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
}
func TestProcess_StopCmd(t *testing.T) {
config := getTestSimpleResponderConfig("test_stop_cmd")
conf := getTestSimpleResponderConfig("test_stop_cmd")
if runtime.GOOS == "windows" {
config.CmdStop = "taskkill /f /t /pid ${PID}"
conf.CmdStop = "taskkill /f /t /pid ${PID}"
} else {
config.CmdStop = "kill -TERM ${PID}"
conf.CmdStop = "kill -TERM ${PID}"
}
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
@@ -470,15 +473,15 @@ func TestProcess_StopCmd(t *testing.T) {
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
expectedMessage := "test_env_not_emptied"
config := getTestSimpleResponderConfig(expectedMessage)
conf := getTestSimpleResponderConfig(expectedMessage)
// ensure that the the default config does not blank out the inherited environment
configWEnv := config
configWEnv := conf
// ensure the additiona variables are appended to the process' environment
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
process1.start()
@@ -491,3 +494,74 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
}
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
// handled appropriately.
//
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
// can't copy the body. This can be caused by a client disconnecting before the full
// response is sent from some reason.
//
// bug: https://github.com/mostlygeek/llama-swap/issues/362
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
defer func() {
if r := recover(); r != nil {
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
}
}()
expectedMessage := "panic_test"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
defer process.Stop()
// Start the process
err := process.start()
assert.Nil(t, err)
assert.Equal(t, StateReady, process.CurrentState())
// Create a custom ResponseWriter that simulates a client disconnect
// by panicking when Write is called after headers are sent
panicWriter := &panicOnWriteResponseWriter{
ResponseRecorder: httptest.NewRecorder(),
shouldPanic: true,
}
// Make a request that will trigger the panic
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
// ProxyRequest should catch and handle this panic gracefully.
process.ProxyRequest(panicWriter, req)
// If we get here, the panic was properly recovered in ProxyRequest
// The process should still be in a ready state
assert.Equal(t, StateReady, process.CurrentState())
}
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
// to simulate a client disconnect after headers are sent
// used by: TestProcess_ReverseProxyPanicIsHandled
type panicOnWriteResponseWriter struct {
*httptest.ResponseRecorder
shouldPanic bool
headerWritten bool
}
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
w.headerWritten = true
w.ResponseRecorder.WriteHeader(statusCode)
}
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
if w.shouldPanic && w.headerWritten {
// Simulate the panic that httputil.ReverseProxy throws
panic(http.ErrAbortHandler)
}
return w.ResponseRecorder.Write(b)
}
+12
View File
@@ -0,0 +1,12 @@
//go:build !windows
package proxy
import (
"os/exec"
)
// setProcAttributes sets platform-specific process attributes
func setProcAttributes(cmd *exec.Cmd) {
// No-op on Unix systems
}
+16
View File
@@ -0,0 +1,16 @@
//go:build windows
package proxy
import (
"os/exec"
"syscall"
)
// setProcAttributes sets platform-specific process attributes
func setProcAttributes(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
HideWindow: true,
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
}
}
+46 -3
View File
@@ -5,12 +5,14 @@ import (
"net/http"
"slices"
"sync"
"github.com/mostlygeek/llama-swap/proxy/config"
)
type ProcessGroup struct {
sync.Mutex
config Config
config config.Config
id string
swap bool
exclusive bool
@@ -24,7 +26,7 @@ type ProcessGroup struct {
lastUsedProcess string
}
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
@@ -44,7 +46,8 @@ func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstream
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
processLogger := NewLogMonitorWriter(upstreamLogger)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
@@ -60,10 +63,20 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
if pg.swap {
pg.Lock()
if pg.lastUsedProcess != modelID {
// is there something already running?
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
}
// wait for the request to the new model to be fully handled
// and prevent race conditions see issue #277
pg.processes[modelID].ProxyRequest(writer, request)
pg.lastUsedProcess = modelID
// short circuit and exit
pg.Unlock()
return nil
}
pg.Unlock()
}
@@ -76,6 +89,36 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
if pg.HasMember(modelName) {
return pg.processes[modelName], true
}
return nil, false
}
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
pg.Lock()
process, exists := pg.processes[modelID]
if !exists {
pg.Unlock()
return fmt.Errorf("process not found for %s", modelID)
}
if pg.lastUsedProcess == modelID {
pg.lastUsedProcess = ""
}
pg.Unlock()
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
return nil
}
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
+39 -20
View File
@@ -4,21 +4,23 @@ import (
"bytes"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
)
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
@@ -33,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
})
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}
@@ -44,32 +46,49 @@ func TestProcessGroup_HasMember(t *testing.T) {
assert.False(t, pg.HasMember("model3"))
}
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
// and multiple requests are made in parallel, only one process is running at a time.
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
// use the same listening so if a model is already running, it will fail
// this is a way to test that swap isolation is working
// properly when there are parallel requests made at the
// same time.
"model1": getTestSimpleResponderConfigPort("model1", 9832),
"model2": getTestSimpleResponderConfigPort("model2", 9832),
"model3": getTestSimpleResponderConfigPort("model3", 9832),
"model4": getTestSimpleResponderConfigPort("model4", 9832),
"model5": getTestSimpleResponderConfigPort("model5", 9832),
},
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2", "model3", "model4", "model5"},
},
},
})
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
tests := []string{"model1", "model2"}
tests := []string{"model1", "model2", "model3", "model4", "model5"}
var wg sync.WaitGroup
wg.Add(len(tests))
for _, modelName := range tests {
t.Run(modelName, func(t *testing.T) {
reqBody := `{"x", "y"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
go func(modelName string) {
defer wg.Done()
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
// make sure only one process is in the running state
count := 0
for _, process := range pg.processes {
if process.CurrentState() == StateReady {
count++
}
}
assert.Equal(t, 1, count)
})
}(modelName)
}
wg.Wait()
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
+380 -98
View File
@@ -16,6 +16,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -24,10 +25,12 @@ const (
PROFILE_SPLIT_CHAR = ":"
)
type proxyCtxKey string
type ProxyManager struct {
sync.Mutex
config Config
config config.Config
ginEngine *gin.Engine
// logging
@@ -35,26 +38,54 @@ type ProxyManager struct {
upstreamLogger *LogMonitor
muxLogger *LogMonitor
metricsMonitor *MetricsMonitor
metricsMonitor *metricsMonitor
processGroups map[string]*ProcessGroup
// shutdown signaling
shutdownCtx context.Context
shutdownCancel context.CancelFunc
// version info
buildDate string
commit string
version string
// peer proxy see: #296, #433
peerProxy *PeerProxy
}
func New(config Config) *ProxyManager {
func New(proxyConfig config.Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
var muxLogger, upstreamLogger, proxyLogger *LogMonitor
switch proxyConfig.LogToStdout {
case config.LogToStdoutNone:
muxLogger = NewLogMonitorWriter(io.Discard)
upstreamLogger = NewLogMonitorWriter(io.Discard)
proxyLogger = NewLogMonitorWriter(io.Discard)
case config.LogToStdoutBoth:
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(muxLogger)
proxyLogger = NewLogMonitorWriter(muxLogger)
case config.LogToStdoutUpstream:
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(muxLogger)
proxyLogger = NewLogMonitorWriter(io.Discard)
default:
// same as config.LogToStdoutProxy
// helpful because some old tests create a config.Config directly and it
// may not have LogToStdout set explicitly
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(io.Discard)
proxyLogger = NewLogMonitorWriter(muxLogger)
}
if proxyConfig.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
@@ -72,53 +103,103 @@ func New(config Config) *ProxyManager {
upstreamLogger.SetLogLevel(LevelInfo)
}
// see: https://go.dev/src/time/format.go
timeFormats := map[string]string{
"ansic": time.ANSIC,
"unixdate": time.UnixDate,
"rubydate": time.RubyDate,
"rfc822": time.RFC822,
"rfc822z": time.RFC822Z,
"rfc850": time.RFC850,
"rfc1123": time.RFC1123,
"rfc1123z": time.RFC1123Z,
"rfc3339": time.RFC3339,
"rfc3339nano": time.RFC3339Nano,
"kitchen": time.Kitchen,
"stamp": time.Stamp,
"stampmilli": time.StampMilli,
"stampmicro": time.StampMicro,
"stampnano": time.StampNano,
}
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
proxyLogger.SetLogTimeFormat(timeFormat)
upstreamLogger.SetLogTimeFormat(timeFormat)
}
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
var maxMetrics int
if proxyConfig.MetricsMaxInMemory <= 0 {
maxMetrics = 1000 // Default fallback
} else {
maxMetrics = proxyConfig.MetricsMaxInMemory
}
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
if err != nil {
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
peerProxy = nil
}
pm := &ProxyManager{
config: config,
config: proxyConfig,
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
muxLogger: muxLogger,
upstreamLogger: upstreamLogger,
metricsMonitor: NewMetricsMonitor(&config),
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
processGroups: make(map[string]*ProcessGroup),
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
buildDate: "unknown",
commit: "abcd1234",
version: "0",
peerProxy: peerProxy,
}
// create the process groups
for groupID := range config.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
for groupID := range proxyConfig.Groups {
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
pm.setupGinEngine()
// run any startup hooks
if len(config.Hooks.OnStartup.Preload) > 0 {
if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
// do it in the background, don't block startup -- not sure if good idea yet
go func() {
discardWriter := &DiscardWriter{}
for _, realModelName := range config.Hooks.OnStartup.Preload {
proxyLogger.Infof("Preloading model: %s", realModelName)
processGroup, _, err := pm.swapProcessGroup(realModelName)
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
modelID, ok := proxyConfig.RealModelName(preloadModelName)
if !ok {
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
continue
}
proxyLogger.Infof("Preloading model: %s", modelID)
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
event.Emit(ModelPreloadedEvent{
ModelName: realModelName,
ModelName: modelID,
Success: false,
})
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
continue
} else {
req, _ := http.NewRequest("GET", "/", nil)
processGroup.ProxyRequest(realModelName, discardWriter, req)
processGroup.ProxyRequest(modelID, discardWriter, req)
event.Emit(ModelPreloadedEvent{
ModelName: realModelName,
ModelName: modelID,
Success: true,
})
}
@@ -130,7 +211,15 @@ func New(config Config) *ProxyManager {
}
func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.Use(func(c *gin.Context) {
// don't log the Wake on Lan proxy health check
if c.Request.URL.Path == "/wol-health" {
c.Next()
return
}
// Start timer
start := time.Now()
@@ -184,35 +273,39 @@ func (pm *ProxyManager) setupGinEngine() {
c.Next()
})
mm := MetricsMiddleware(pm)
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
// Protected routes use pm.apiKeyAuth() middleware
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// llama-server's /completion endpoint
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
// in proxymanager_loghandlers.go
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
/**
* User Interface Endpoints
@@ -224,14 +317,18 @@ func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
c.Redirect(http.StatusFound, "/ui/models")
})
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
pm.ginEngine.GET("/health", func(c *gin.Context) {
c.String(http.StatusOK, "OK")
})
// see cmd/wol-proxy/wol-proxy.go, not logged
pm.ginEngine.GET("/wol-health", func(c *gin.Context) {
c.String(http.StatusOK, "OK")
})
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
c.Data(http.StatusOK, "image/x-icon", data)
@@ -320,16 +417,10 @@ func (pm *ProxyManager) Shutdown() {
pm.shutdownCancel()
}
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
}
if processGroup.exclusive {
@@ -341,20 +432,16 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
}
}
return processGroup, realModelName, nil
return processGroup, nil
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := make([]gin.H, 0, len(pm.config.Models))
createdTime := time.Now().Unix()
for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
record := gin.H{
"id": id,
"id": modelId,
"object": "model",
"created": createdTime,
"owned_by": "llama-swap",
@@ -367,7 +454,47 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
record["description"] = desc
}
data = append(data, record)
// Add metadata if present
if len(modelConfig.Metadata) > 0 {
record["meta"] = gin.H{
"llamaswap": modelConfig.Metadata,
}
}
return record
}
for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
data = append(data, newRecord(id, modelConfig))
// Include aliases
if pm.config.IncludeAliasesInList {
for _, alias := range modelConfig.Aliases {
if alias := strings.TrimSpace(alias); alias != "" {
data = append(data, newRecord(alias, modelConfig))
}
}
}
}
if pm.peerProxy != nil {
for peerID, peer := range pm.peerProxy.ListPeers() {
// add peer models
for _, modelID := range peer.Models {
// Skip unlisted models if not showing them
record := newRecord(modelID, config.ModelConfig{
Name: fmt.Sprintf("%s: %s", peerID, modelID),
Metadata: map[string]any{
"peerID": peerID,
},
})
data = append(data, record)
}
}
}
// Sort by the "id" key
@@ -389,26 +516,87 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
})
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
requestedModel := c.Param("model_id")
// findModelInPath searches for a valid model name in a path with slashes.
// It iteratively builds up path segments until it finds a matching model.
// Returns: (searchModelName, realModelName, remainingPath, found)
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
searchModelName := ""
if requestedModel == "" {
for i, part := range parts {
if part == "" {
continue
}
if searchModelName == "" {
searchModelName = part
} else {
searchModelName = searchModelName + "/" + part
}
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
}
}
return "", "", "", false
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
upstreamPath := c.Param("upstreamPath")
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
if !modelFound {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return
}
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
// This ensures relative URLs in upstream responses resolve correctly and
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
// HTTP method (301 would downgrade to GET).
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
newPath := "/upstream/" + searchModelName + "/"
if c.Request.URL.RawQuery != "" {
newPath += "?" + c.Request.URL.RawQuery
}
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
c.Redirect(http.StatusMovedPermanently, newPath)
} else {
c.Redirect(http.StatusPermanentRedirect, newPath)
}
return
}
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// rewrite the path
c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
originalPath := c.Request.URL.Path
c.Request.URL.Path = remainingPath
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return
}
} else {
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
return
}
}
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
@@ -421,41 +609,54 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
return
}
// Look for a matching local model first
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
processGroup, _, err := pm.swapProcessGroup(realModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[modelID].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
return
}
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -465,10 +666,24 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
}
@@ -486,7 +701,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return
}
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
modelID, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
return
}
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
@@ -504,7 +725,7 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
useModelName := pm.config.Models[modelID].UseModelName
if useModelName != "" {
fieldValue = useModelName
@@ -575,9 +796,9 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
if err := processGroup.ProxyRequest(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, modelID)
return
}
}
@@ -592,6 +813,59 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
}
}
// apiKeyAuth returns a middleware that validates API keys if configured.
// Returns a pass-through handler if no API keys are configured.
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
if len(pm.config.RequiredAPIKeys) == 0 {
return func(c *gin.Context) { c.Next() }
}
return func(c *gin.Context) {
xApiKey := c.GetHeader("x-api-key")
var bearerKey string
if auth := c.GetHeader("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
bearerKey = strings.TrimPrefix(auth, "Bearer ")
}
}
// If both headers present, they must match
if xApiKey != "" && bearerKey != "" && xApiKey != bearerKey {
pm.sendErrorResponse(c, http.StatusBadRequest, "x-api-key and Authorization header values do not match")
c.Abort()
return
}
// Use x-api-key first, then Authorization
providedKey := xApiKey
if providedKey == "" {
providedKey = bearerKey
}
// Validate key
valid := false
for _, key := range pm.config.RequiredAPIKeys {
if providedKey == key {
valid = true
break
}
}
if !valid {
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
c.Abort()
return
}
// Strip auth headers to prevent leakage to upstream
c.Request.Header.Del("Authorization")
c.Request.Header.Del("x-api-key")
c.Next()
}
}
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK")
@@ -628,3 +902,11 @@ func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
}
return nil
}
func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) {
pm.Lock()
defer pm.Unlock()
pm.buildDate = buildDate
pm.commit = commit
pm.version = version
}
+53 -3
View File
@@ -3,8 +3,10 @@ package proxy
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"strings"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
@@ -16,15 +18,19 @@ type Model struct {
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
PeerID string `json:"peerID"`
}
func addApiHandlers(pm *ProxyManager) {
// Add API endpoints for React to consume
apiGroup := pm.ginEngine.Group("/api")
// Protected with API key authentication
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
{
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics)
apiGroup.GET("/version", pm.apiGetVersion)
}
}
@@ -78,6 +84,18 @@ func (pm *ProxyManager) getModelStatus() []Model {
})
}
// Iterate over the peer models
if pm.peerProxy != nil {
for peerID, peer := range pm.peerProxy.ListPeers() {
for _, modelID := range peer.Models {
models = append(models, Model{
Id: modelID,
PeerID: peerID,
})
}
}
}
return models
}
@@ -100,6 +118,8 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
// prevent nginx from buffering SSE
c.Header("X-Accel-Buffering", "no")
sendBuffer := make(chan messageEnvelope, 25)
ctx, cancel := context.WithCancel(c.Request.Context())
@@ -175,7 +195,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
sendLogData("proxy", pm.proxyLogger.GetHistory())
sendLogData("upstream", pm.upstreamLogger.GetHistory())
sendModels()
sendMetrics(pm.metricsMonitor.GetMetrics())
sendMetrics(pm.metricsMonitor.getMetrics())
for {
select {
@@ -193,10 +213,40 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
}
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
jsonData, err := pm.metricsMonitor.getMetricsJSON()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
return
}
c.Data(http.StatusOK, "application/json", jsonData)
}
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
return
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
return
}
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
return
} else {
c.String(http.StatusOK, "OK")
}
}
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
c.JSON(http.StatusOK, map[string]string{
"version": pm.version,
"commit": pm.commit,
"build_date": pm.buildDate,
})
}
+22 -13
View File
@@ -28,8 +28,10 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.Header("Transfer-Encoding", "chunked")
c.Header("X-Content-Type-Options", "nosniff")
// prevent nginx from buffering streamed logs
c.Header("X-Accel-Buffering", "no")
logMonitorId := c.Param("logMonitorID")
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
@@ -81,18 +83,25 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
switch logMonitorId {
case "":
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return pm.muxLogger, nil
case "proxy":
return pm.proxyLogger, nil
case "upstream":
return pm.upstreamLogger, nil
default:
// search for a models specific logger using findModelInPath
// to handle model names with slashes (e.g., "author/model")
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
for _, group := range pm.processGroups {
if process, found := group.GetMember(name); found {
return process.Logger(), nil
}
}
}
return logger, nil
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
}
}
+769 -141
View File
File diff suppressed because it is too large Load Diff
+172 -92
View File
File diff suppressed because it is too large Load Diff
+4 -5
View File
@@ -4,23 +4,21 @@
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"start": "vite",
"build": "tsc -b && vite build --emptyOutDir",
"lint": "eslint .",
"preview": "vite preview"
},
"dependencies": {
"@tailwindcss/vite": "^4.1.8",
"@tanstack/react-query": "^5.80.6",
"react": "^19.1.0",
"react-dom": "^19.1.0",
"react-icons": "^5.5.0",
"react-resizable-panels": "^3.0.4",
"react-router-dom": "^7.6.2",
"tailwindcss": "^4.1.8"
"react-router-dom": "^7.6.2"
},
"devDependencies": {
"@eslint/js": "^9.25.0",
"@tailwindcss/vite": "^4.1.8",
"@types/react": "^19.1.2",
"@types/react-dom": "^19.1.2",
"@vitejs/plugin-react": "^4.4.1",
@@ -28,6 +26,7 @@
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-react-refresh": "^0.4.19",
"globals": "^16.0.0",
"tailwindcss": "^4.1.8",
"typescript": "~5.8.3",
"typescript-eslint": "^8.30.1",
"vite": "^6.3.5"
+7 -49
View File
@@ -1,21 +1,14 @@
import { useEffect, useCallback } from "react";
import { BrowserRouter as Router, Routes, Route, Navigate, NavLink } from "react-router-dom";
import { useTheme } from "./contexts/ThemeProvider";
import { useEffect } from "react";
import { Navigate, Route, BrowserRouter as Router, Routes } from "react-router-dom";
import { Header } from "./components/Header";
import { useAPI } from "./contexts/APIProvider";
import { useTheme } from "./contexts/ThemeProvider";
import ActivityPage from "./pages/Activity";
import LogViewerPage from "./pages/LogViewer";
import ModelPage from "./pages/Models";
import ActivityPage from "./pages/Activity";
import ConnectionStatusIcon from "./components/ConnectionStatus";
import { RiSunFill, RiMoonFill } from "react-icons/ri";
function App() {
const { isNarrow, toggleTheme, isDarkMode, appTitle, setAppTitle, setConnectionState } = useTheme();
const handleTitleChange = useCallback(
(newTitle: string) => {
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
},
[setAppTitle]
);
const { setConnectionState } = useTheme();
const { connectionStatus } = useAPI();
@@ -27,42 +20,7 @@ function App() {
return (
<Router basename="/ui/">
<div className="flex flex-col h-screen">
<nav className="bg-surface border-b border-border p-2 h-[75px]">
<div className="flex items-center justify-between mx-auto px-4 h-full">
{!isNarrow && (
<h1
contentEditable
suppressContentEditableWarning
className="flex items-center p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded px-1"
onBlur={(e) => handleTitleChange(e.currentTarget.textContent || "(set title)")}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.preventDefault();
handleTitleChange(e.currentTarget.textContent || "(set title)");
e.currentTarget.blur();
}
}}
>
{appTitle}
</h1>
)}
<div className="flex items-center space-x-4">
<NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Logs
</NavLink>
<NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Models
</NavLink>
<NavLink to="/activity" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Activity
</NavLink>
<button className="" onClick={toggleTheme}>
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
</button>
<ConnectionStatusIcon />
</div>
</div>
</nav>
<Header />
<main className="flex-1 overflow-auto p-4">
<Routes>
+4 -4
View File
@@ -2,14 +2,14 @@ import { useAPI } from "../contexts/APIProvider";
import { useMemo } from "react";
const ConnectionStatusIcon = () => {
const { connectionStatus } = useAPI();
const { connectionStatus, versionInfo } = useAPI();
const eventStatusColor = useMemo(() => {
switch (connectionStatus) {
case "connected":
return "bg-green-500";
return "bg-emerald-500";
case "connecting":
return "bg-yellow-500";
return "bg-amber-500";
case "disconnected":
default:
return "bg-red-500";
@@ -17,7 +17,7 @@ const ConnectionStatusIcon = () => {
}, [connectionStatus]);
return (
<div className="flex items-center" title={`event stream: ${connectionStatus}`}>
<div className="flex items-center" title={`Event Stream: ${connectionStatus ?? 'unknown'}\nAPI Version: ${versionInfo?.version ?? 'unknown'}\nCommit Hash: ${versionInfo?.commit?.substring(0,7) ?? 'unknown'}\nBuild Date: ${versionInfo?.build_date ?? 'unknown'}`}>
<span className={`inline-block w-3 h-3 rounded-full ${eventStatusColor} mr-2`}></span>
</div>
);
+56
View File
@@ -0,0 +1,56 @@
import { useCallback } from "react";
import { RiMoonFill, RiSunFill } from "react-icons/ri";
import { NavLink, type NavLinkRenderProps } from "react-router-dom";
import { useTheme } from "../contexts/ThemeProvider";
import ConnectionStatusIcon from "./ConnectionStatus";
export function Header() {
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle, isNarrow } = useTheme();
const handleTitleChange = useCallback(
(newTitle: string) => {
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
},
[setAppTitle]
);
const navLinkClass = ({ isActive }: NavLinkRenderProps) =>
`text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`;
return (
<header className={`flex items-center justify-between bg-surface border-b border-border px-4 ${isNarrow ? "py-1 h-[60px]" : "p-2 h-[75px]"}`}>
{screenWidth !== "xs" && screenWidth !== "sm" && (
<h1
contentEditable
suppressContentEditableWarning
className="p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
onBlur={(e) => handleTitleChange(e.currentTarget.textContent || "(set title)")}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.preventDefault();
handleTitleChange(e.currentTarget.textContent || "(set title)");
e.currentTarget.blur();
}
}}
>
{appTitle}
</h1>
)}
<menu className="flex items-center gap-4">
<NavLink to="/" className={navLinkClass} type="button">
Logs
</NavLink>
<NavLink to="/models" className={navLinkClass} type="button">
Models
</NavLink>
<NavLink to="/activity" className={navLinkClass} type="button">
Activity
</NavLink>
<button className="" onClick={toggleTheme}>
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
</button>
<ConnectionStatusIcon />
</menu>
</header>
);
}
+79 -13
View File
@@ -1,4 +1,4 @@
import { useRef, createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
import { createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
import type { ConnectionState } from "../lib/types";
type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
@@ -10,24 +10,28 @@ export interface Model {
name: string;
description: string;
unlisted: boolean;
peerID: string;
}
interface APIProviderType {
models: Model[];
listModels: () => Promise<Model[]>;
unloadAllModels: () => Promise<void>;
unloadSingleModel: (model: string) => Promise<void>;
loadModel: (model: string) => Promise<void>;
enableAPIEvents: (enabled: boolean) => void;
proxyLogs: string;
upstreamLogs: string;
metrics: Metrics[];
connectionStatus: ConnectionState;
versionInfo: VersionInfo;
}
interface Metrics {
id: number;
timestamp: string;
model: string;
cache_tokens: number;
input_tokens: number;
output_tokens: number;
prompt_per_second: number;
@@ -39,23 +43,37 @@ interface LogData {
source: "upstream" | "proxy";
data: string;
}
interface APIEventEnvelope {
type: "modelStatus" | "logData" | "metrics";
data: string;
}
interface VersionInfo {
build_date: string;
commit: string;
version: string;
}
const APIContext = createContext<APIProviderType | undefined>(undefined);
type APIProviderProps = {
children: ReactNode;
autoStartAPIEvents?: boolean;
};
let apiEventSource: EventSource | null = null;
export function APIProvider({ children, autoStartAPIEvents = true }: APIProviderProps) {
const [proxyLogs, setProxyLogs] = useState("");
const [upstreamLogs, setUpstreamLogs] = useState("");
const [metrics, setMetrics] = useState<Metrics[]>([]);
const [connectionStatus, setConnectionState] = useState<ConnectionState>("disconnected");
const apiEventSource = useRef<EventSource | null>(null);
const [versionInfo, setVersionInfo] = useState<VersionInfo>({
build_date: "unknown",
commit: "unknown",
version: "unknown",
});
//const apiEventSource = useRef<EventSource | null>(null);
const [models, setModels] = useState<Model[]>([]);
@@ -68,8 +86,8 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
const enableAPIEvents = useCallback((enabled: boolean) => {
if (!enabled) {
apiEventSource.current?.close();
apiEventSource.current = null;
apiEventSource?.close();
apiEventSource = null;
setMetrics([]);
return;
}
@@ -78,22 +96,22 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
const initialDelay = 1000; // 1 second
const connect = () => {
apiEventSource.current = null;
const eventSource = new EventSource("/api/events");
apiEventSource?.close();
apiEventSource = new EventSource("/api/events");
setConnectionState("connecting");
eventSource.onopen = () => {
apiEventSource.onopen = () => {
// clear everything out on connect to keep things in sync
setProxyLogs("");
setUpstreamLogs("");
setMetrics([]); // clear metrics on reconnect
setModels([]); // clear models on reconnect
apiEventSource.current = eventSource;
retryCount = 0;
setConnectionState("connected");
};
eventSource.onmessage = (e: MessageEvent) => {
apiEventSource.onmessage = (e: MessageEvent) => {
try {
const message = JSON.parse(e.data) as APIEventEnvelope;
switch (message.type) {
@@ -136,8 +154,8 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
}
};
eventSource.onerror = () => {
eventSource.close();
apiEventSource.onerror = () => {
apiEventSource?.close();
retryCount++;
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
setConnectionState("disconnected");
@@ -148,6 +166,26 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
connect();
}, []);
useEffect(() => {
// fetch version
const fetchVersion = async () => {
try {
const response = await fetch("/api/version");
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const data: VersionInfo = await response.json();
setVersionInfo(data);
} catch (error) {
console.error(error);
}
};
if (connectionStatus === "connected") {
fetchVersion();
}
}, [connectionStatus]);
useEffect(() => {
if (autoStartAPIEvents) {
enableAPIEvents(true);
@@ -174,7 +212,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
const unloadAllModels = useCallback(async () => {
try {
const response = await fetch(`/api/models/unload/`, {
const response = await fetch(`/api/models/unload`, {
method: "POST",
});
if (!response.ok) {
@@ -186,6 +224,20 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
}
}, []);
const unloadSingleModel = useCallback(async (model: string) => {
try {
const response = await fetch(`/api/models/unload/${model}`, {
method: "POST",
});
if (!response.ok) {
throw new Error(`Failed to unload model: ${response.status}`);
}
} catch (error) {
console.error("Failed to unload model", model, error);
throw error;
}
}, []);
const loadModel = useCallback(async (model: string) => {
try {
const response = await fetch(`/upstream/${model}/`, {
@@ -205,14 +257,28 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
models,
listModels,
unloadAllModels,
unloadSingleModel,
loadModel,
enableAPIEvents,
proxyLogs,
upstreamLogs,
metrics,
connectionStatus,
versionInfo,
}),
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
[
models,
listModels,
unloadAllModels,
unloadSingleModel,
loadModel,
enableAPIEvents,
proxyLogs,
upstreamLogs,
metrics,
connectionStatus,
versionInfo,
]
);
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
+10 -2
View File
@@ -93,6 +93,14 @@
@apply px-4;
}
/* Tables */
table th {
@apply p-2 font-semibold;
}
table td {
@apply p-2;
}
/* Navigation Header */
.navlink {
@@ -122,7 +130,7 @@
/* Status Badges */
.status {
@apply inline-block px-2 py-1 text-xs font-medium rounded-full;
@apply inline-block px-2 py-1 text-xs font-medium rounded-lg;
}
.status--ready {
@@ -140,7 +148,7 @@
/* Buttons */
.btn {
@apply bg-surface p-2 px-4 text-sm rounded-full border border-2 transition-colors duration-200 border-btn-border;
@apply bg-surface py-2 px-4 text-sm rounded-md border transition-colors duration-200 border-btn-border;
}
.btn:hover {
+82 -28
View File
@@ -1,10 +1,6 @@
import { useMemo } from "react";
import { useAPI } from "../contexts/APIProvider";
const formatTimestamp = (timestamp: string): string => {
return new Date(timestamp).toLocaleString();
};
const formatSpeed = (speed: number): string => {
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
};
@@ -13,6 +9,33 @@ const formatDuration = (ms: number): string => {
return (ms / 1000).toFixed(2) + "s";
};
const formatRelativeTime = (timestamp: string): string => {
const now = new Date();
const date = new Date(timestamp);
const diffInSeconds = Math.floor((now.getTime() - date.getTime()) / 1000);
// Handle future dates by returning "just now"
if (diffInSeconds < 5) {
return "now";
}
if (diffInSeconds < 60) {
return `${diffInSeconds}s ago`;
}
const diffInMinutes = Math.floor(diffInSeconds / 60);
if (diffInMinutes < 60) {
return `${diffInMinutes}m ago`;
}
const diffInHours = Math.floor(diffInMinutes / 60);
if (diffInHours < 24) {
return `${diffInHours}h ago`;
}
return "a while ago";
};
const ActivityPage = () => {
const { metrics } = useAPI();
const sortedMetrics = useMemo(() => {
@@ -20,39 +43,46 @@ const ActivityPage = () => {
}, [metrics]);
return (
<div className="p-6">
<h1 className="text-2xl font-bold mb-4">Activity</h1>
<div className="p-2">
<h1 className="text-2xl font-bold">Activity</h1>
{metrics.length === 0 ? (
{metrics.length === 0 && (
<div className="text-center py-8">
<p className="text-gray-600">No metrics data available</p>
</div>
) : (
<div className="overflow-x-auto">
)}
{metrics.length > 0 && (
<div className="card overflow-auto">
<table className="min-w-full divide-y">
<thead>
<tr>
<th className="px-4 py-3 text-left text-xs font-medium uppercase tracking-wider">Id</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Timestamp</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Model</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Input Tokens</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Output Tokens</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Prompt Processing</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Generation Speed</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Duration</th>
<thead className="border-gray-200 dark:border-white/10">
<tr className="text-left text-xs uppercase tracking-wider">
<th className="px-6 py-3">ID</th>
<th className="px-6 py-3">Time</th>
<th className="px-6 py-3">Model</th>
<th className="px-6 py-3">
Cached <Tooltip content="prompt tokens from cache" />
</th>
<th className="px-6 py-3">
Prompt <Tooltip content="new prompt tokens processed" />
</th>
<th className="px-6 py-3">Generated</th>
<th className="px-6 py-3">Prompt Processing</th>
<th className="px-6 py-3">Generation Speed</th>
<th className="px-6 py-3">Duration</th>
</tr>
</thead>
<tbody className="divide-y">
{sortedMetrics.map((metric) => (
<tr key={`metric_${metric.id}`}>
<td className="px-4 py-4 whitespace-nowrap text-sm">{metric.id + 1 /* un-zero index */}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatTimestamp(metric.timestamp)}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.model}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.input_tokens.toLocaleString()}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.output_tokens.toLocaleString()}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.prompt_per_second)}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.tokens_per_second)}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatDuration(metric.duration_ms)}</td>
<tr key={metric.id} className="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
<td className="px-4 py-4">{metric.id + 1 /* un-zero index */}</td>
<td className="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
<td className="px-6 py-4">{metric.model}</td>
<td className="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td>
<td className="px-6 py-4">{metric.input_tokens.toLocaleString()}</td>
<td className="px-6 py-4">{metric.output_tokens.toLocaleString()}</td>
<td className="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td>
<td className="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td>
<td className="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
</tr>
))}
</tbody>
@@ -63,4 +93,28 @@ const ActivityPage = () => {
);
};
interface TooltipProps {
content: string;
}
const Tooltip: React.FC<TooltipProps> = ({ content }) => {
return (
<div className="relative group inline-block">
<div
className="absolute top-full left-1/2 transform -translate-x-1/2 mt-2
px-3 py-2 bg-gray-900 text-white text-sm rounded-md
opacity-0 group-hover:opacity-100 transition-opacity
duration-200 pointer-events-none whitespace-nowrap z-50 normal-case"
>
{content}
<div
className="absolute bottom-full left-1/2 transform -translate-x-1/2
border-4 border-transparent border-b-gray-900"
></div>
</div>
</div>
);
};
export default ActivityPage;
+9 -9
View File
@@ -14,8 +14,8 @@ import { useTheme } from "../contexts/ThemeProvider";
const LogViewer = () => {
const { proxyLogs, upstreamLogs } = useAPI();
const { isNarrow } = useTheme();
const direction = isNarrow ? "vertical" : "horizontal";
const { screenWidth } = useTheme();
const direction = screenWidth === "xs" || screenWidth === "sm" ? "vertical" : "horizontal";
return (
<PanelGroup direction={direction} className="gap-2" autoSaveId="logviewer-panel-group">
@@ -115,19 +115,19 @@ export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
}, [filteredLogs]);
return (
<div className="bg-surface border border-border rounded-lg overflow-hidden flex flex-col h-full">
<div className="p-4 border-b border-border bg-secondary">
<div className="rounded-lg overflow-hidden flex flex-col bg-gray-950/5 dark:bg-white/10 h-full p-1">
<div className="p-4">
<div className="flex items-center justify-between">
<h3 className="m-0 text-lg p-0">{title}</h3>
<div className="flex gap-2 items-center">
<button className="btn" onClick={toggleFontSize}>
<button className="btn border-0" onClick={toggleFontSize}>
<RiFontSize />
</button>
<button className="btn" onClick={toggleWrapText}>
<button className="btn border-0" onClick={toggleWrapText}>
{wrapText ? <RiTextWrap /> : <RiAlignJustify />}
</button>
<button className="btn" onClick={toggleFilter}>
<button className="btn border-0" onClick={toggleFilter}>
{showFilter ? <RiMenuSearchFill /> : <RiMenuSearchLine />}
</button>
</div>
@@ -139,7 +139,7 @@ export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
<div className="flex gap-2 items-center w-full">
<input
type="text"
className="w-full text-sm border p-2 rounded"
className="w-full text-sm border border-gray-950/10 dark:border-white/5 p-2 rounded outline-none"
placeholder="Filter logs..."
value={filterRegex}
onChange={(e) => setFilterRegex(e.target.value)}
@@ -151,7 +151,7 @@ export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
</div>
)}
</div>
<div className="bg-background font-mono text-sm flex-1 overflow-hidden">
<div className="rounded-lg bg-background font-mono text-sm flex-1 overflow-hidden">
<pre ref={preTagRef} className={`${textWrapClass} ${fontSizeClass} h-full overflow-auto p-4`}>
{filteredLogs}
</pre>
+409 -55
View File
@@ -4,7 +4,7 @@ import { LogPanel } from "./LogViewer";
import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import { useTheme } from "../contexts/ThemeProvider";
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri";
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine, RiMenuFill } from "react-icons/ri";
export default function ModelsPage() {
const { isNarrow } = useTheme();
@@ -37,13 +37,31 @@ export default function ModelsPage() {
}
function ModelsPanel() {
const { models, loadModel, unloadAllModels } = useAPI();
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
const { isNarrow } = useTheme();
const [isUnloading, setIsUnloading] = useState(false);
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
const [menuOpen, setMenuOpen] = useState(false);
const filteredModels = useMemo(() => {
return models.filter((model) => showUnlisted || !model.unlisted);
const { regularModels, peerModelsByPeerId } = useMemo(() => {
const filtered = models.filter((model) => showUnlisted || !model.unlisted);
const peerModels = filtered.filter((m) => m.peerID);
// Group peer models by peerID
const grouped = peerModels.reduce((acc, model) => {
const peerId = model.peerID || "unknown";
if (!acc[peerId]) {
acc[peerId] = [];
}
acc[peerId].push(model);
return acc;
}, {} as Record<string, typeof peerModels>);
return {
regularModels: filtered.filter((m) => !m.peerID),
peerModelsByPeerId: grouped,
};
}, [models, showUnlisted]);
const handleUnloadAllModels = useCallback(async () => {
@@ -66,104 +84,440 @@ function ModelsPanel() {
return (
<div className="card h-full flex flex-col">
<div className="shrink-0">
<h2>Models</h2>
<div className="flex justify-between">
<div className="flex gap-2">
<button className="btn flex items-center gap-2" onClick={toggleIdorName} style={{ lineHeight: "1.2" }}>
<RiSwapBoxFill /> {showIdorName === "id" ? "ID" : "Name"}
</button>
<div className="flex justify-between items-baseline">
<h2 className={isNarrow ? "text-xl" : ""}>Models</h2>
{isNarrow && (
<div className="relative">
<button className="btn text-base flex items-center gap-2 py-1" onClick={() => setMenuOpen(!menuOpen)}>
<RiMenuFill size="20" />
</button>
{menuOpen && (
<div className="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
<button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
toggleIdorName();
setMenuOpen(false);
}}
>
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "Show Name" : "Show ID"}
</button>
<button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
setShowUnlisted(!showUnlisted);
setMenuOpen(false);
}}
>
{showUnlisted ? <RiEyeOffFill size="20" /> : <RiEyeFill size="20" />}{" "}
{showUnlisted ? "Hide Unlisted" : "Show Unlisted"}
</button>
<button
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onClick={() => {
handleUnloadAllModels();
setMenuOpen(false);
}}
disabled={isUnloading}
>
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
)}
</div>
)}
</div>
{!isNarrow && (
<div className="flex justify-between">
<div className="flex gap-2">
<button
className="btn text-base flex items-center gap-2"
onClick={toggleIdorName}
style={{ lineHeight: "1.2" }}
>
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"}
</button>
<button
className="btn text-base flex items-center gap-2"
onClick={() => setShowUnlisted(!showUnlisted)}
style={{ lineHeight: "1.2" }}
>
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted
</button>
</div>
<button
className="btn flex items-center gap-2"
onClick={() => setShowUnlisted(!showUnlisted)}
style={{ lineHeight: "1.2" }}
className="btn text-base flex items-center gap-2"
onClick={handleUnloadAllModels}
disabled={isUnloading}
>
{showUnlisted ? <RiEyeFill /> : <RiEyeOffFill />} unlisted
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
<button className="btn flex items-center gap-2" onClick={handleUnloadAllModels} disabled={isUnloading}>
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
</button>
</div>
)}
</div>
<div className="flex-1 overflow-y-auto">
<table className="w-full">
<thead className="sticky top-0 bg-card z-10">
<tr className="border-b border-primary bg-surface">
<th className="text-left p-2">{showIdorName === "id" ? "Model ID" : "Name"}</th>
<th className="text-left p-2"></th>
<th className="text-left p-2">State</th>
<tr className="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
<th>{showIdorName === "id" ? "Model ID" : "Name"}</th>
<th></th>
<th>State</th>
</tr>
</thead>
<tbody>
{filteredModels.map((model) => (
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
<td className={`p-2 ${model.unlisted ? "text-txtsecondary" : ""}`}>
<a href={`/upstream/${model.id}/`} className={`underline`} target="_blank">
{regularModels.map((model) => (
<tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200">
<td className={`${model.unlisted ? "text-txtsecondary" : ""}`}>
<a href={`/upstream/${model.id}/`} className="font-semibold" target="_blank">
{showIdorName === "id" ? model.id : model.name !== "" ? model.name : model.id}
</a>
{model.description !== "" && (
{!!model.description && (
<p className={model.unlisted ? "text-opacity-70" : ""}>
<em>{model.description}</em>
</p>
)}
</td>
<td className="p-2 w-[50px]">
<button
className="btn btn--sm"
disabled={model.state !== "stopped"}
onClick={() => loadModel(model.id)}
>
Load
</button>
<td className="w-12">
{model.state === "stopped" ? (
<button className="btn btn--sm" onClick={() => loadModel(model.id)}>
Load
</button>
) : (
<button
className="btn btn--sm"
onClick={() => unloadSingleModel(model.id)}
disabled={model.state !== "ready"}
>
Unload
</button>
)}
</td>
<td className="p-2 w-[75px]">
<span className={`status status--${model.state}`}>{model.state}</span>
<td className="w-20">
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>
</td>
</tr>
))}
</tbody>
</table>
{Object.keys(peerModelsByPeerId).length > 0 && (
<>
<h3 className="mt-8 mb-2">Peer Models</h3>
{Object.entries(peerModelsByPeerId)
.sort(([a], [b]) => a.localeCompare(b))
.map(([peerId, models]) => (
<div key={peerId} className="mb-4">
<table className="w-full">
<thead className="sticky top-0 bg-card z-10">
<tr className="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
<th className="font-semibold">{peerId}</th>
</tr>
</thead>
<tbody>
{models.map((model) => (
<tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200">
<td className={`pl-8 ${model.unlisted ? "text-txtsecondary" : ""}`}>
<span>{model.id}</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
))}
</>
)}
</div>
</div>
);
}
interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
function TokenHistogram({ data }: { data: HistogramData }) {
const { bins, min, max, p50, p95, p99 } = data;
const maxCount = Math.max(...bins);
const height = 120;
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
// Use viewBox for responsive sizing
const viewBoxWidth = 600;
const chartWidth = viewBoxWidth - padding.left - padding.right;
const chartHeight = height - padding.top - padding.bottom;
const barWidth = chartWidth / bins.length;
const range = max - min;
// Calculate x position for a given value
const getXPosition = (value: number) => {
return padding.left + ((value - min) / range) * chartWidth;
};
return (
<div className="mt-2 w-full">
<svg viewBox={`0 0 ${viewBoxWidth} ${height}`} className="w-full h-auto" preserveAspectRatio="xMidYMid meet">
{/* Y-axis */}
<line
x1={padding.left}
y1={padding.top}
x2={padding.left}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="1"
opacity="0.3"
/>
{/* X-axis */}
<line
x1={padding.left}
y1={height - padding.bottom}
x2={viewBoxWidth - padding.right}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="1"
opacity="0.3"
/>
{/* Histogram bars */}
{bins.map((count, i) => {
const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0;
const x = padding.left + i * barWidth;
const y = height - padding.bottom - barHeight;
const binStart = min + i * data.binSize;
const binEnd = binStart + data.binSize;
return (
<g key={i}>
<rect
x={x}
y={y}
width={Math.max(barWidth - 1, 1)}
height={barHeight}
fill="currentColor"
opacity="0.6"
className="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
/>
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
</g>
);
})}
{/* Percentile lines */}
<line
x1={getXPosition(p50)}
y1={padding.top}
x2={getXPosition(p50)}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="2"
strokeDasharray="4 2"
opacity="0.7"
className="text-gray-600 dark:text-gray-400"
/>
<line
x1={getXPosition(p95)}
y1={padding.top}
x2={getXPosition(p95)}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="2"
strokeDasharray="4 2"
opacity="0.7"
className="text-orange-500 dark:text-orange-400"
/>
<line
x1={getXPosition(p99)}
y1={padding.top}
x2={getXPosition(p99)}
y2={height - padding.bottom}
stroke="currentColor"
strokeWidth="2"
strokeDasharray="4 2"
opacity="0.7"
className="text-green-500 dark:text-green-400"
/>
{/* X-axis labels */}
<text x={padding.left} y={height - 5} fontSize="10" fill="currentColor" opacity="0.6" textAnchor="start">
{min.toFixed(1)}
</text>
<text
x={viewBoxWidth - padding.right}
y={height - 5}
fontSize="10"
fill="currentColor"
opacity="0.6"
textAnchor="end"
>
{max.toFixed(1)}
</text>
{/* X-axis label */}
<text
x={padding.left + chartWidth / 2}
y={height - 2}
fontSize="10"
fill="currentColor"
opacity="0.6"
textAnchor="middle"
>
Tokens/Second Distribution
</text>
</svg>
</div>
);
}
function StatsPanel() {
const { metrics } = useAPI();
const [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond] = useMemo(() => {
const [totalRequests, totalInputTokens, totalOutputTokens, tokenStats, histogramData] = useMemo(() => {
const totalRequests = metrics.length;
if (totalRequests === 0) {
return [0, 0, 0];
return [0, 0, 0, { p99: 0, p95: 0, p50: 0 }, null];
}
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
return [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond];
// Calculate token statistics using output_tokens and duration_ms
// Filter out metrics with invalid duration or output tokens
const validMetrics = metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
if (validMetrics.length === 0) {
return [totalRequests, totalInputTokens, totalOutputTokens, { p99: 0, p95: 0, p50: 0 }, null];
}
// Calculate tokens/second for each valid metric
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
// Sort for percentile calculation
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
// Calculate percentiles - showing speed thresholds where X% of requests are SLOWER (below)
// P99: 99% of requests are slower than this speed (99th percentile - fast requests)
// P95: 95% of requests are slower than this speed (95th percentile)
// P50: 50% of requests are slower than this speed (median)
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
// Create histogram data
const min = Math.min(...tokensPerSecond);
const max = Math.max(...tokensPerSecond);
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5))); // Adaptive bin count
const binSize = (max - min) / binCount;
const bins = Array(binCount).fill(0);
tokensPerSecond.forEach((value) => {
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
bins[binIndex]++;
});
const histogramData = {
bins,
min,
max,
binSize,
p99,
p95,
p50,
};
return [
totalRequests,
totalInputTokens,
totalOutputTokens,
{
p99: p99.toFixed(2),
p95: p95.toFixed(2),
p50: p50.toFixed(2),
},
histogramData,
];
}, [metrics]);
const nf = new Intl.NumberFormat();
return (
<div className="card">
<div className="rounded-lg overflow-hidden border border-gray-200">
<table className="w-full">
<tbody>
<div className="rounded-lg overflow-hidden border border-card-border-inner">
<table className="min-w-full divide-y divide-card-border-inner">
<thead className="bg-secondary">
<tr>
<th className="p-2 font-medium border-b border-gray-200 text-right">Requests</th>
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Processed</th>
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Generated</th>
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Tokens/Sec</th>
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">
Requests
</th>
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Processed
</th>
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Generated
</th>
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Token Stats (tokens/sec)
</th>
</tr>
<tr>
<td className="p-2 text-right border-r border-gray-200">{totalRequests}</td>
<td className="p-2 text-right border-r border-gray-200">
{new Intl.NumberFormat().format(totalInputTokens)}
</thead>
<tbody className="bg-surface divide-y divide-card-border-inner">
<tr className="hover:bg-secondary">
<td className="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">{totalRequests}</td>
<td className="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div className="flex items-center gap-2">
<span className="text-sm font-medium">{nf.format(totalInputTokens)}</span>
<span className="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td className="p-2 text-right border-r border-gray-200">
{new Intl.NumberFormat().format(totalOutputTokens)}
<td className="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div className="flex items-center gap-2">
<span className="text-sm font-medium">{nf.format(totalOutputTokens)}</span>
<span className="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td className="px-4 py-4 border-l border-gray-200 dark:border-white/10">
<div className="space-y-3">
<div className="grid grid-cols-3 gap-2 items-center">
<div className="text-center">
<div className="text-xs text-gray-500 dark:text-gray-400">P50</div>
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{tokenStats.p50}
</div>
</div>
<div className="text-center">
<div className="text-xs text-gray-500 dark:text-gray-400">P95</div>
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{tokenStats.p95}
</div>
</div>
<div className="text-center">
<div className="text-xs text-gray-500 dark:text-gray-400">P99</div>
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{tokenStats.p99}
</div>
</div>
</div>
{histogramData && <TokenHistogram data={histogramData} />}
</div>
</td>
<td className="p-2 text-right">{avgTokensPerSecond}</td>
</tr>
</tbody>
</table>
+1
View File
@@ -15,6 +15,7 @@ export default defineConfig({
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
"/logs": "http://localhost:8080",
"/upstream": "http://localhost:8080",
"/unload": "http://localhost:8080",
},
},
});