Compare commits
81 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dea98733c3 | |||
| bccce5fa19 | |||
| c968da1b73 | |||
| a883d68d4f | |||
| b1dec8b735 | |||
| 06523d8c1e | |||
| 86e9b93c37 | |||
| 3acace810f | |||
| 554d29e87d | |||
| 3567b7df08 | |||
| 38738525c9 | |||
| c0fc858193 | |||
| b429349e8a | |||
| eab2efd7b5 | |||
| 6aedbe121a | |||
| b24467ab89 | |||
| 12b69fb718 | |||
| f91a8b2462 | |||
| a89b803d4a | |||
| f852689104 | |||
| e250e71e59 | |||
| d18dc26d01 | |||
| 8357714421 | |||
| c07179d6e2 | |||
| 7ff50631e0 | |||
| 9fc0431531 | |||
| 6516532568 | |||
| d58a8b85bf | |||
| caf9e98b1e | |||
| 539278343b | |||
| 00b738cd0f | |||
| 70930e4e91 | |||
| 1f6179110c | |||
| 216c40b951 | |||
| 9e3d491c85 | |||
| 1a84926505 | |||
| fc3bb716df | |||
| c36986fef6 | |||
| 558801db1a | |||
| b21dee27c1 | |||
| f58c8c8ec5 | |||
| 954e2dee73 | |||
| a533aec736 | |||
| 97b17fc47d | |||
| 2457840698 | |||
| 7f55494151 | |||
| 831a90d3b0 | |||
| 977f1856bb | |||
| 52b329f7bc | |||
| 57803fd3aa | |||
| c55d0cc842 | |||
| 7acbaf4712 | |||
| fcc5ad135a | |||
| 305e5a0031 | |||
| 04fc67354a | |||
| 4662cf7699 | |||
| 5dc6b3e6d9 | |||
| 74c69f39ef | |||
| a186318892 | |||
| c4e4d5e1e9 | |||
| 7985e94ba4 | |||
| 74556c3a36 | |||
| 5c381e4b30 | |||
| 10569ed546 | |||
| 5b10b3c23f | |||
| 45ea792a3a | |||
| 1bc2802353 | |||
| 701476c0c4 | |||
| 5c63e0066c | |||
| 8be5073c51 | |||
| 6307bd3205 | |||
| 558a72de17 | |||
| dc42cf366d | |||
| ba0a81937a | |||
| 574fdfabb4 | |||
| 5172cb2e12 | |||
| 5672cb03fd | |||
| 0f583163f7 | |||
| 7905fa9ea3 | |||
| bbaf172956 | |||
| fd50932dbc |
@@ -1,11 +1,13 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: Something is not working as expected...
|
||||
about: I found a defect
|
||||
title: ''
|
||||
labels: bug
|
||||
labels: 'unconfirmed bug'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
> [!IMPORTANT]
|
||||
> If you have questions about llama-swap please post in the Q&A in Discussions. Use bug reports when you've found a defect and wish to discuss a fix.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
name: Validate JSON Schema
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
validate-schema:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Validate JSON Schema
|
||||
run: |
|
||||
# Check if the file is valid JSON
|
||||
if ! jq empty config-schema.json 2>/dev/null; then
|
||||
echo "Error: config-schema.json is not valid JSON"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate that it's a valid JSON Schema
|
||||
# Check for required $schema field
|
||||
if ! jq -e '."$schema"' config-schema.json > /dev/null; then
|
||||
echo "Warning: config-schema.json should have a \$schema field"
|
||||
fi
|
||||
|
||||
# Check that it has either properties or definitions
|
||||
if ! jq -e '.properties or .definitions or ."$defs"' config-schema.json > /dev/null; then
|
||||
echo "Warning: JSON Schema should contain properties, definitions, or \$defs"
|
||||
fi
|
||||
|
||||
echo "✓ config-schema.json is valid"
|
||||
@@ -22,6 +22,13 @@ jobs:
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# Only run in this linux based runner
|
||||
- name: Check Formatting
|
||||
run: |
|
||||
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||
gofmt -l . | grep -v 'event/.*_test.go'
|
||||
exit 1
|
||||
fi
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
|
||||
@@ -7,6 +7,10 @@ on:
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag version to release (e.g. v144)'
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -20,15 +24,15 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
|
||||
-
|
||||
name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '23' # or your preferred version
|
||||
node-version: '23'
|
||||
-
|
||||
name: Install dependencies and build UI
|
||||
run: |
|
||||
@@ -46,4 +50,30 @@ jobs:
|
||||
version: '~> v2'
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
trigger-tap-update:
|
||||
runs-on: ubuntu-latest
|
||||
needs: goreleaser
|
||||
steps:
|
||||
- name: "Resolve tag to dispatch"
|
||||
id: tag
|
||||
run: |
|
||||
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: "Trigger tap repository update"
|
||||
uses: peter-evans/repository-dispatch@v2
|
||||
with:
|
||||
token: ${{ secrets.TAP_REPO_PAT }}
|
||||
repository: mostlygeek/homebrew-llama-swap
|
||||
event-type: new-release
|
||||
client-payload: |
|
||||
{
|
||||
"release": {
|
||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||
}
|
||||
}
|
||||
@@ -4,3 +4,4 @@ build/
|
||||
dist/
|
||||
.vscode
|
||||
.DS_Store
|
||||
.dev/
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
# Project: llama-swap
|
||||
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and react for UI (ui/)
|
||||
|
||||
## Testing
|
||||
|
||||
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
### Plan Improvements
|
||||
|
||||
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
||||
|
||||
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
||||
|
||||
- Identify any inconsistencies.
|
||||
- Expand plans out to be detailed specification of requirements and changes to be made.
|
||||
- Plans should have at least these sections:
|
||||
- Title - very short, describes changes
|
||||
- Overview: A more detailed summary of goal and outcomes desired
|
||||
- Design Requirements: Detailed descriptions of what needs to be done
|
||||
- Testing Plan: Tests to be implemented
|
||||
- Checklist: A detailed list of changes to be made
|
||||
|
||||
Look for "plan expansion" as explicit instructions to improve a plan.
|
||||
|
||||
### Implementation of plans
|
||||
|
||||
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
||||
|
||||
## General Rules
|
||||
|
||||
- when summarizing changes only include details that require further action (action items)
|
||||
- when there are no action items, just say "Done."
|
||||
@@ -23,11 +23,17 @@ proxy/ui_dist/placeholder.txt:
|
||||
mkdir -p proxy/ui_dist
|
||||
touch $@
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -v -count=1 ./proxy
|
||||
# use cached test results while developing
|
||||
test-dev: proxy/ui_dist/placeholder.txt
|
||||
go test -short ./proxy/...
|
||||
staticcheck ./proxy/... || true
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -count=1 ./proxy/...
|
||||
|
||||
# for CI - full test (takes longer)
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -v -count=1 ./proxy
|
||||
go test -race -count=1 ./proxy/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui && npm install
|
||||
@@ -45,6 +51,7 @@ mac: ui
|
||||
linux: ui
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||
|
||||
# Build Windows binary
|
||||
windows: ui
|
||||
@@ -54,12 +61,12 @@ windows: ui
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 cmd/simple-responder/simple-responder.go
|
||||
|
||||
simple-responder-windows:
|
||||
@echo "Building simple responder for windows"
|
||||
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
|
||||
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe cmd/simple-responder/simple-responder.go
|
||||
|
||||
# Ensure build directory exists
|
||||
$(BUILD_DIR):
|
||||
@@ -79,5 +86,11 @@ release:
|
||||
echo "tagging new version: $$new_tag"; \
|
||||
git tag "$$new_tag";
|
||||
|
||||
GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
|
||||
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
|
||||
wol-proxy: $(BUILD_DIR)
|
||||
@echo "Building wol-proxy"
|
||||
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean ui mac linux windows simple-responder
|
||||
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
|
||||
|
||||
@@ -5,148 +5,191 @@
|
||||
|
||||
# llama-swap
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
|
||||
|
||||
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
||||
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
||||
|
||||
## Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Easy to config: single yaml file
|
||||
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- future proof, upgrade your inference servers at any time.
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`, `v1/reranking`, `rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- ✅ llama-swap custom API endpoints
|
||||
- ✅ llama-server (llama.cpp) supported endpoints
|
||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||
- `/infill` - for code infilling
|
||||
- `/completion` - for completion endpoint
|
||||
- ✅ llama-swap API
|
||||
- `/ui` - web UI
|
||||
- `/log` - remote log monitoring
|
||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Docker and Podman support
|
||||
- ✅ Full control over server settings per model
|
||||
- `/log` - remote log monitoring
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ Customizable
|
||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- Automatic unloading of models after timeout by setting a `ttl`
|
||||
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
||||
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||
|
||||
## How does llama-swap work?
|
||||
### Web UI
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
llama-swap includes a real time web interface for monitoring logs and controlling models:
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## config.yaml
|
||||
|
||||
llama-swap is managed entirely through a yaml configuration file.
|
||||
|
||||
It can be very minimal to start:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"qwen2.5":
|
||||
cmd: |
|
||||
/path/to/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port ${PORT}
|
||||
```
|
||||
|
||||
However, there are many more capabilities that llama-swap supports:
|
||||
|
||||
- `groups` to run multiple models at once
|
||||
- `ttl` to automatically unload models
|
||||
- `macros` for reusable snippets
|
||||
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||
- `env` to pass custom environment variables to inference servers
|
||||
- `cmdStop` for to gracefully stop Docker/Podman containers
|
||||
- `useModelName` to override model names sent to upstream servers
|
||||
- `healthCheckTimeout` to control model startup wait times
|
||||
- `${PORT}` automatic port variables for dynamic port assignment
|
||||
|
||||
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
|
||||
|
||||
## Web UI
|
||||
|
||||
llama-swap ships with a real time web interface to monitor logs and status of models:
|
||||
|
||||
<img width="1786" height="1334" alt="image" src="https://github.com/user-attachments/assets/d6258cb9-1dad-40db-828f-2be860aec8fe" />
|
||||
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
|
||||
|
||||
|
||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
The Activity Page shows recent requests:
|
||||
|
||||
Docker is the quickest way to try out llama-swap:
|
||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||
|
||||
## Installation
|
||||
|
||||
llama-swap can be installed in multiple ways
|
||||
|
||||
1. Docker
|
||||
2. Homebrew (OSX and Linux)
|
||||
3. WinGet
|
||||
4. From release binaries
|
||||
5. From source
|
||||
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc).
|
||||
|
||||
```shell
|
||||
# use CPU inference comes with the example config above
|
||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
|
||||
# qwen2.5 0.5B
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
|
||||
# SmolLM2 135M
|
||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer no-key" \
|
||||
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||
jq -r '.choices[0].message.content'
|
||||
# run with a custom configuration and models directory
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Docker images are built nightly for cuda, intel, vulcan, etc ...</summary>
|
||||
|
||||
They include:
|
||||
|
||||
- `ghcr.io/mostlygeek/llama-swap:cpu`
|
||||
- `ghcr.io/mostlygeek/llama-swap:cuda`
|
||||
- `ghcr.io/mostlygeek/llama-swap:intel`
|
||||
- `ghcr.io/mostlygeek/llama-swap:vulkan`
|
||||
- ROCm disabled until fixed in llama.cpp container
|
||||
|
||||
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
|
||||
|
||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||
<summary>
|
||||
more examples
|
||||
</summary>
|
||||
|
||||
```shell
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
# pull latest images per platform
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cpu
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:vulkan
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:intel
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:musa
|
||||
|
||||
# tagged llama-swap, platform and llama-server version images
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||
### Homebrew Install (macOS/Linux)
|
||||
|
||||
Pre-built binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
||||
```shell
|
||||
brew tap mostlygeek/llama-swap
|
||||
brew install llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
|
||||
Available flags:
|
||||
- `--config`: Path to the configuration file (default: `config.yaml`).
|
||||
- `--listen`: Address and port to listen on (default: `:8080`).
|
||||
- `--version`: Show version information and exit.
|
||||
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
|
||||
### WinGet Install (Windows)
|
||||
|
||||
> [!NOTE]
|
||||
> WinGet is maintained by community contributor [Dvd-Znf](https://github.com/Dvd-Znf) ([#327](https://github.com/mostlygeek/llama-swap/issues/327)). It is not an official part of llama-swap.
|
||||
|
||||
```shell
|
||||
# install
|
||||
C:\> winget install llama-swap
|
||||
|
||||
# upgrade
|
||||
C:\> winget upgrade llama-swap
|
||||
```
|
||||
|
||||
### Pre-built Binaries
|
||||
|
||||
Binaries are available on the [release](https://github.com/mostlygeek/llama-swap/releases) page for Linux, Mac, Windows and FreeBSD.
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Build requires golang and nodejs for the user interface.
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. Building requires Go and Node.js (for UI).
|
||||
1. `git clone https://github.com/mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
1. look in the `build/` subdirectory for the llama-swap binary
|
||||
|
||||
## Monitoring Logs
|
||||
## Configuration
|
||||
|
||||
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
|
||||
```yaml
|
||||
# minimum viable config.yaml
|
||||
|
||||
CLI access is also supported:
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
|
||||
```
|
||||
|
||||
That's all you need to get started:
|
||||
|
||||
1. `models` - holds all model configurations
|
||||
2. `model1` - the ID used in API calls
|
||||
3. `cmd` - the command to run to start the server.
|
||||
4. `${PORT}` - an automatically assigned port number
|
||||
|
||||
Almost all configuration settings are optional and can be added one step at a time:
|
||||
|
||||
- Advanced features
|
||||
- `groups` to run multiple models at once
|
||||
- `hooks` to run things on startup
|
||||
- `macros` reusable snippets
|
||||
- Model customization
|
||||
- `ttl` to automatically unload models
|
||||
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||
- `env` to pass custom environment variables to inference servers
|
||||
- `cmdStop` gracefully stop Docker/Podman containers
|
||||
- `useModelName` to override model names sent to upstream servers
|
||||
- `${PORT}` automatic port variables for dynamic port assignment
|
||||
- `filters` rewrite parts of requests before sending to the upstream server
|
||||
|
||||
See the [configuration documentation](docs/configuration.md) for all options.
|
||||
|
||||
## How does llama-swap work?
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## Reverse Proxy Configuration (nginx)
|
||||
|
||||
If you deploy llama-swap behind nginx, disable response buffering for streaming endpoints. By default, nginx buffers responses which breaks Server‑Sent Events (SSE) and streaming chat completion. ([#236](https://github.com/mostlygeek/llama-swap/issues/236))
|
||||
|
||||
Recommended nginx configuration snippets:
|
||||
|
||||
```nginx
|
||||
# SSE for UI events/logs
|
||||
location /api/events {
|
||||
proxy_pass http://your-llama-swap-backend;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Streaming chat completions (stream=true)
|
||||
location /v1/chat/completions {
|
||||
proxy_pass http://your-llama-swap-backend;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
```
|
||||
|
||||
As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. However, explicitly disabling `proxy_buffering` at your reverse proxy is still recommended for reliable streaming behavior.
|
||||
|
||||
## Monitoring Logs on the CLI
|
||||
|
||||
```shell
|
||||
# sends up to the last 10KB of logs
|
||||
@@ -172,15 +215,11 @@ curl -Ns 'http://host/logs/stream?no-history'
|
||||
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
|
||||
## Contributors
|
||||
<a href="https://github.com/mostlygeek/llama-swap/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=mostlygeek/llama-swap" />
|
||||
</a>
|
||||
|
||||
Made with [contrib.rocks](https://contrib.rocks).
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals for proper shutdown.
|
||||
|
||||
## Star History
|
||||
|
||||
> [!NOTE]
|
||||
> ⭐️ Star this project to help others discover it!
|
||||
|
||||
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
# Add Model Metadata Support with Typed Macros
|
||||
|
||||
## Overview
|
||||
|
||||
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
|
||||
|
||||
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### 1. Enhanced Macro System
|
||||
|
||||
**Current State:**
|
||||
|
||||
- Macros are defined as `map[string]string` at both global and model levels
|
||||
- Only string substitution is supported
|
||||
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Change `MacroList` type from `map[string]string` to `map[string]any`
|
||||
- Support scalar types: `string`, `int`, `float64`, `bool`
|
||||
- Implement type-preserving macro substitution:
|
||||
- Direct macro usage (`key: ${macro}`) preserves the macro's type
|
||||
- Interpolated usage (`key: "text ${macro}"`) converts to string
|
||||
- Add validation to ensure macro values are scalar types only
|
||||
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
|
||||
|
||||
**Implementation Details:**
|
||||
|
||||
- Create a generic helper function to perform macro substitution that:
|
||||
- Takes a value of type `any`
|
||||
- Recursively processes maps, slices, and scalar values
|
||||
- Replaces `${macro_name}` patterns with macro values
|
||||
- Preserves types for direct substitution
|
||||
- Converts to strings for interpolated substitution
|
||||
- Update `validateMacro()` function to accept `any` type and validate scalar types
|
||||
- Maintain backward compatibility with existing string-only macros
|
||||
|
||||
### 2. Metadata Field in ModelConfig
|
||||
|
||||
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Add `Metadata map[string]any` field to `ModelConfig` struct
|
||||
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
|
||||
- Apply macro substitution to metadata values during config loading
|
||||
|
||||
**Schema Requirements:**
|
||||
|
||||
- Metadata is optional (default: empty/nil map)
|
||||
- Supports nested structures (objects within objects, arrays, etc.)
|
||||
- All string values within metadata undergo macro substitution
|
||||
- Type preservation rules apply as described above
|
||||
|
||||
### 3. Macro Substitution in Metadata
|
||||
|
||||
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
|
||||
|
||||
**Process Flow:**
|
||||
|
||||
1. After loading YAML configuration
|
||||
2. After model-level and global macro merging
|
||||
3. Apply macro substitution to `ModelConfig.Metadata` field
|
||||
4. Use the same merged macros available to `cmd`, `proxy`, etc.
|
||||
5. Process recursively through all nested structures
|
||||
|
||||
**Substitution Rules:**
|
||||
|
||||
- `port: ${PORT}` → keeps integer type from PORT macro
|
||||
- `temperature: ${temp}` → keeps float type from temp macro
|
||||
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
|
||||
- Arrays and nested objects are processed recursively
|
||||
- Unknown macros should cause configuration load error (consistent with existing behavior)
|
||||
|
||||
### 4. API Response Updates
|
||||
|
||||
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
|
||||
|
||||
**Current Behavior:**
|
||||
|
||||
- Returns model records with: `id`, `object`, `created`, `owned_by`
|
||||
- Optionally includes: `name`, `description`
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Add metadata to each model record under the key `llamaswap_meta`
|
||||
- Only include `llamaswap_meta` if metadata is non-empty
|
||||
- Preserve all types when marshaling to JSON
|
||||
- Maintain existing sorting by model ID
|
||||
|
||||
**Example Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "llama",
|
||||
"object": "model",
|
||||
"created": 1234567890,
|
||||
"owned_by": "llama-swap",
|
||||
"name": "llama 3.1 8B",
|
||||
"description": "A small but capable model",
|
||||
"llamaswap_meta": {
|
||||
"port": 10001,
|
||||
"temperature": 0.7,
|
||||
"note": "The llama is running on port 10001 temp=0.7, context=16384",
|
||||
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
|
||||
"an_obj": {
|
||||
"a": "1",
|
||||
"b": 2,
|
||||
"c": [0.7, false, "model: llama"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Validation and Error Handling
|
||||
|
||||
**Macro Validation:**
|
||||
|
||||
- Extend `validateMacro()` to accept values of type `any`
|
||||
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
|
||||
- Reject complex types (maps, slices, structs) as macro values
|
||||
- Maintain existing validation for macro names and lengths
|
||||
|
||||
**Configuration Loading:**
|
||||
|
||||
- Fail fast if unknown macros are found in metadata
|
||||
- Provide clear error messages indicating which model and field contains errors
|
||||
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
|
||||
|
||||
## Testing Plan
|
||||
|
||||
### Test 1: Model-Level Macros with Different Types
|
||||
|
||||
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Define model with macros of each scalar type
|
||||
- Verify metadata correctly substitutes and preserves types
|
||||
- Test direct substitution (`port: ${PORT}`)
|
||||
- Test string interpolation (`note: "Port is ${PORT}"`)
|
||||
- Verify nested objects and arrays work correctly
|
||||
|
||||
### Test 2: Global and Model Macro Precedence
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Define same macro at global and model level with different types
|
||||
- Verify model-level macro takes precedence
|
||||
- Test metadata uses correct macro value
|
||||
- Verify type is preserved from the winning macro
|
||||
|
||||
### Test 3: Macro Validation
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Test that complex types (maps, arrays) are rejected as macro values
|
||||
- Verify error message includes: macro name and type that was rejected
|
||||
- Test that scalar types (string, int, float, bool) are accepted
|
||||
- Each type should load without error
|
||||
- Test macro name validation still works with `any` types
|
||||
- Invalid characters, reserved names, length limits should still be enforced
|
||||
|
||||
### Test 4: Metadata in API Response
|
||||
|
||||
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||
|
||||
**Existing Test:** `TestProxyManager_ListModelsHandler`
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Model with metadata → verify `llamaswap_meta` key appears
|
||||
- Model without metadata → verify `llamaswap_meta` key is absent
|
||||
- Verify all types are correctly marshaled to JSON
|
||||
- Verify nested structures are preserved
|
||||
- Verify macro substitution has occurred before serialization
|
||||
|
||||
### Test 5: Unknown Macros in Metadata
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Use undefined macro in metadata
|
||||
- Verify configuration loading fails with clear error
|
||||
- Error should indicate model name and that macro is undefined
|
||||
|
||||
### Test 6: Recursive Substitution
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Metadata with deeply nested structures
|
||||
- Arrays containing objects with macros
|
||||
- Objects containing arrays with macros
|
||||
- Mixed string interpolation and direct substitution at various nesting levels
|
||||
|
||||
## Checklist
|
||||
|
||||
### Configuration Schema Changes
|
||||
|
||||
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
|
||||
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
|
||||
- [x] Update `validateMacro()` function signature to accept `any` type for values
|
||||
- [x] Add validation logic to ensure macro values are scalar types only
|
||||
|
||||
### Macro Substitution Logic
|
||||
|
||||
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
|
||||
- [x] Implement type-preserving direct substitution logic
|
||||
- [x] Implement string interpolation with type conversion
|
||||
- [x] Handle maps: recursively process all values
|
||||
- [x] Handle slices: recursively process all elements
|
||||
- [x] Handle scalar types: perform string-based macro substitution if value is string
|
||||
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
|
||||
- [x] Update existing macro substitution calls to use merged macros with correct types
|
||||
|
||||
### API Response Changes
|
||||
|
||||
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
|
||||
- [x] Add `llamaswap_meta` field to model records when metadata exists
|
||||
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
|
||||
- [x] Verify JSON marshaling preserves all types correctly
|
||||
|
||||
### Testing - Config Package
|
||||
|
||||
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
### Testing - Model Config Package
|
||||
|
||||
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
|
||||
- [x] Test metadata with various scalar types
|
||||
- [x] Test metadata with nested objects and arrays
|
||||
|
||||
### Testing - Proxy Manager
|
||||
|
||||
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||
- [x] Add test case for model with metadata
|
||||
- [x] Add test case for model without metadata
|
||||
- [x] Verify `llamaswap_meta` key presence/absence
|
||||
- [x] Verify type preservation in JSON output
|
||||
- [x] Verify macro substitution has occurred
|
||||
|
||||
### Documentation
|
||||
|
||||
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
|
||||
- [x] No additional documentation needed per project instructions
|
||||
|
||||
## Known Issues and Considerations
|
||||
|
||||
### Inconsistencies
|
||||
|
||||
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
|
||||
|
||||
### Design Decisions
|
||||
|
||||
1. **Why `llamaswap_meta` instead of merging into record?**
|
||||
|
||||
- Avoids potential collisions with OpenAI API standard fields
|
||||
- Makes it clear this is llama-swap specific metadata
|
||||
- Easier for clients to distinguish standard vs. custom fields
|
||||
|
||||
2. **Why support nested structures?**
|
||||
|
||||
- Provides maximum flexibility for users
|
||||
- Aligns with the schemaless design principle
|
||||
- Example config already demonstrates this capability
|
||||
|
||||
3. **Why validate macro types?**
|
||||
- Prevents confusing behavior (e.g., substituting a map)
|
||||
- Makes configuration errors explicit at load time
|
||||
- Simpler implementation and testing
|
||||
@@ -0,0 +1,397 @@
|
||||
# Improve macro-in-macro support
|
||||
|
||||
**Status: COMPLETED ✅**
|
||||
|
||||
## Title
|
||||
|
||||
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
|
||||
|
||||
## Overview
|
||||
|
||||
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
|
||||
|
||||
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
|
||||
|
||||
**Outcomes:**
|
||||
- Macros can reference other macros defined earlier in the config
|
||||
- Macro substitution is deterministic and order-dependent
|
||||
- Single-pass substitution prevents circular dependencies
|
||||
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
|
||||
- All existing tests pass
|
||||
- New tests validate substitution order and self-reference detection
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### 1. YAML Parsing Strategy
|
||||
- **Continue using:** `gopkg.in/yaml.v3` (current library)
|
||||
- **Use:** `yaml.Node` for ordered parsing of macros
|
||||
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
|
||||
|
||||
### 2. Data Structure Changes
|
||||
|
||||
#### Current Implementation (config.go:19)
|
||||
```go
|
||||
type MacroList map[string]any
|
||||
```
|
||||
|
||||
#### New Implementation
|
||||
```go
|
||||
type MacroList []MacroEntry
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
```
|
||||
|
||||
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
|
||||
|
||||
### 3. Macro Substitution Order Rules
|
||||
|
||||
The substitution must follow this hierarchy (from most specific to least):
|
||||
|
||||
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
|
||||
2. **Model-level macros** (middle): Defined in specific model config, overrides global
|
||||
3. **Global macros** (first): Defined at config root level
|
||||
|
||||
Within each level, macros are substituted in **reverse definition order** (LIFO):
|
||||
- The last macro defined is substituted first
|
||||
- This allows later macros to reference earlier ones
|
||||
- Single-pass substitution prevents circular dependencies
|
||||
|
||||
### 4. Macro Reference Rules
|
||||
|
||||
**Allowed:**
|
||||
- Macro can reference any macro defined **before** it (earlier in the file)
|
||||
- Model macros can reference global macros
|
||||
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
|
||||
|
||||
**Prohibited:**
|
||||
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
|
||||
- Macro cannot reference macros defined **after** it
|
||||
- No circular references (prevented by single-pass, ordered substitution)
|
||||
|
||||
### 5. Validation Requirements
|
||||
|
||||
Add validation to detect:
|
||||
- **Self-references:** Macro value contains reference to its own name
|
||||
- **Unknown macros:** After substitution, any remaining `${...}` references
|
||||
|
||||
Error messages should be clear:
|
||||
```
|
||||
macro 'foo' contains self-reference
|
||||
unknown macro '${bar}' in model.cmd
|
||||
```
|
||||
|
||||
### 6. Implementation Changes
|
||||
|
||||
#### Files to Modify
|
||||
|
||||
1. **[proxy/config/config.go](proxy/config/config.go)**
|
||||
- Line 19: Change `MacroList` type definition
|
||||
- Line 69: Update `Macros MacroList` field
|
||||
- Line 153-157: Update macro validation loop to work with ordered structure
|
||||
- Line 175-188: Update model-level macro validation
|
||||
- Line 181-188: **NEW** Implement proper macro merging respecting order
|
||||
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
|
||||
- Line 389-415: Update `validateMacro` to detect self-references
|
||||
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
|
||||
|
||||
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
|
||||
- Line 33: Update `Macros MacroList` field type
|
||||
|
||||
3. **All test files**
|
||||
- Update test fixtures to use ordered macro definitions
|
||||
- Ensure tests specify macro order explicitly
|
||||
|
||||
#### Core Algorithm
|
||||
|
||||
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
|
||||
|
||||
```go
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
|
||||
|
||||
// Add global macros first
|
||||
for _, entry := range config.Macros {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
|
||||
// Add model macros (can override global)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Add reserved MODEL_ID macro at the end
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Check if PORT macro is needed
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro to the end (highest priority)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Add this new helper function to replace `substituteMetadataMacros`:
|
||||
|
||||
```go
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 7. Self-Reference Detection
|
||||
|
||||
Add to `validateMacro` function:
|
||||
|
||||
```go
|
||||
func validateMacro(name string, value any) error {
|
||||
// ... existing validation ...
|
||||
|
||||
// Check for self-reference
|
||||
if str, ok := value.(string); ok {
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(str, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Plan
|
||||
|
||||
### 1. Migration Tests
|
||||
- **Test:** All existing macro tests still pass after YAML library migration
|
||||
- **Files:** All `*_test.go` files with macro tests
|
||||
|
||||
### 2. Macro Order Tests
|
||||
|
||||
#### Test: Macro-in-macro substitution order
|
||||
```yaml
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: "echo ${B}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
|
||||
|
||||
#### Test: LIFO substitution order
|
||||
```yaml
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: "load ${full}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
|
||||
|
||||
#### Test: Model macro overrides global
|
||||
```yaml
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: "echo ${msg}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
|
||||
|
||||
### 3. Reserved Macro Tests
|
||||
|
||||
#### Test: MODEL_ID substituted in macro
|
||||
```yaml
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: "${podman-llama} -m model.gguf"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
|
||||
|
||||
### 4. Error Detection Tests
|
||||
|
||||
#### Test: Self-reference detection
|
||||
```yaml
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
```
|
||||
**Expected:** Error: `macro 'recursive' contains self-reference`
|
||||
|
||||
#### Test: Undefined macro reference
|
||||
```yaml
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
```
|
||||
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
|
||||
|
||||
### 5. Regression Tests
|
||||
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
|
||||
- Ensure all pass without modification (except test fixtures if needed)
|
||||
|
||||
## Checklist
|
||||
|
||||
### Phase 1: Data Structure Changes
|
||||
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
|
||||
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
|
||||
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
|
||||
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
|
||||
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
|
||||
- [ ] Implement helper functions:
|
||||
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
|
||||
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
|
||||
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
|
||||
|
||||
### Phase 2: Macro Validation Updates
|
||||
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
|
||||
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
|
||||
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
|
||||
- [ ] Test self-reference detection with new test case
|
||||
|
||||
### Phase 3: Macro Substitution Algorithm
|
||||
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
|
||||
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
|
||||
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
|
||||
- [ ] Substitute in metadata within same loop
|
||||
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
|
||||
- [ ] Ensure `PORT` is added after port assignment (if needed)
|
||||
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
|
||||
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
|
||||
|
||||
### Phase 4: Testing
|
||||
- [ ] Run `make test-dev` - fix any static checking errors
|
||||
- [ ] Add test: macro-in-macro basic substitution
|
||||
- [ ] Add test: LIFO substitution order with 3+ macro levels
|
||||
- [ ] Add test: MODEL_ID in global macro used by model
|
||||
- [ ] Add test: PORT in global macro used by model
|
||||
- [ ] Add test: model macro overrides global macro in substitution
|
||||
- [ ] Add test: self-reference detection error
|
||||
- [ ] Add test: undefined macro reference error
|
||||
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
|
||||
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
|
||||
|
||||
### Phase 5: Documentation
|
||||
- [ ] Update plan status in this file (mark completed)
|
||||
- [ ] Update CLAUDE.md if macro behavior needs documentation
|
||||
- [ ] Verify no new error messages need user documentation
|
||||
|
||||
## Bug Example (Original Issue)
|
||||
|
||||
```yaml
|
||||
macros:
|
||||
"podman-llama": >
|
||||
podman run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
|
||||
ghcr.io/ggml-org/llama.cpp:server-cuda
|
||||
|
||||
"standard-options": >
|
||||
--no-mmap --jinja
|
||||
|
||||
"kv8": >
|
||||
-fa on -ctk q8_0 -ctv q8_0
|
||||
```
|
||||
|
||||
**Current Bug:**
|
||||
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
|
||||
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
|
||||
|
||||
**After Fix:**
|
||||
- Macros substituted in LIFO order: `kv8` → `standard-options` → `podman-llama`
|
||||
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
|
||||
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
// created for issue: #252 https://github.com/mostlygeek/llama-swap/issues/252
|
||||
// this simple benchmark tool sends a lot of small chat completion requests to llama-swap
|
||||
// to make sure all the requests are accounted for.
|
||||
//
|
||||
// requests can be sent in parallel, and the tool will report the results.
|
||||
// usage: go run main.go -baseurl http://localhost:8080/v1 -model llama3 -requests 1000 -par 5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// ----- CLI arguments ----------------------------------------------------
|
||||
var (
|
||||
baseurl string
|
||||
modelName string
|
||||
totalRequests int
|
||||
parallelization int
|
||||
)
|
||||
|
||||
flag.StringVar(&baseurl, "baseurl", "http://localhost:8080/v1", "Base URL of the API (e.g., https://api.example.com)")
|
||||
flag.StringVar(&modelName, "model", "", "Model name to use")
|
||||
flag.IntVar(&totalRequests, "requests", 1, "Total number of requests to send")
|
||||
flag.IntVar(¶llelization, "par", 1, "Maximum number of concurrent requests")
|
||||
flag.Parse()
|
||||
|
||||
if baseurl == "" || modelName == "" {
|
||||
fmt.Println("Error: both -baseurl and -model are required.")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
if totalRequests <= 0 {
|
||||
fmt.Println("Error: -requests must be greater than 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
if parallelization <= 0 {
|
||||
fmt.Println("Error: -parallelization must be greater than 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// ----- HTTP client -------------------------------------------------------
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// ----- Tracking response codes -------------------------------------------
|
||||
statusCounts := make(map[int]int) // map[statusCode]count
|
||||
var mu sync.Mutex // protects statusCounts
|
||||
|
||||
// ----- Request queue (buffered channel) ----------------------------------
|
||||
requests := make(chan int, 10) // Buffered channel with capacity 10
|
||||
|
||||
// Goroutine to fill the request queue
|
||||
go func() {
|
||||
for i := 0; i < totalRequests; i++ {
|
||||
requests <- i + 1
|
||||
}
|
||||
close(requests)
|
||||
}()
|
||||
|
||||
// ----- Worker pool -------------------------------------------------------
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < parallelization; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for reqID := range requests {
|
||||
// Build request payload as a single line JSON string
|
||||
payload := `{"model":"` + modelName + `","max_tokens":100,"stream":false,"messages":[{"role":"user","content":"write a snake game in python"}]}`
|
||||
|
||||
// Send POST request
|
||||
req, err := http.NewRequest(http.MethodPost,
|
||||
fmt.Sprintf("%s/chat/completions", baseurl),
|
||||
bytes.NewReader([]byte(payload)))
|
||||
if err != nil {
|
||||
log.Printf("[worker %d][req %d] request creation error: %v", workerID, reqID, err)
|
||||
mu.Lock()
|
||||
statusCounts[-1]++
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[worker %d][req %d] HTTP request error: %v", workerID, reqID, err)
|
||||
mu.Lock()
|
||||
statusCounts[-1]++
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
// Record status code
|
||||
mu.Lock()
|
||||
statusCounts[resp.StatusCode]++
|
||||
mu.Unlock()
|
||||
}
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
// ----- Status ticker (prints every second) -------------------------------
|
||||
done := make(chan struct{})
|
||||
tickerDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
startTime := time.Now()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mu.Lock()
|
||||
// Compute how many requests have completed so far
|
||||
completed := 0
|
||||
for _, cnt := range statusCounts {
|
||||
completed += cnt
|
||||
}
|
||||
// Calculate duration and progress
|
||||
duration := time.Since(startTime)
|
||||
progress := completed * 100 / totalRequests
|
||||
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, progress)
|
||||
mu.Unlock()
|
||||
case <-done:
|
||||
duration := time.Since(startTime)
|
||||
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, 100)
|
||||
close(tickerDone)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for all workers to finish
|
||||
wg.Wait()
|
||||
close(done) // stops the status-update goroutine
|
||||
<-tickerDone // give ticker time to finish / print
|
||||
|
||||
// ----- Summary ------------------------------------------------------------
|
||||
fmt.Println("\n\n=== HTTP response code summary ===")
|
||||
mu.Lock()
|
||||
for code, cnt := range statusCounts {
|
||||
if code == -1 {
|
||||
fmt.Printf("Client-side errors (no HTTP response): %d\n", cnt)
|
||||
} else {
|
||||
fmt.Printf("%d : %d\n", code, cnt)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
@@ -153,6 +153,19 @@ func main() {
|
||||
|
||||
})
|
||||
|
||||
// llama-server compatibility: /completion
|
||||
r.POST("/completion", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// issue #41
|
||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||
// Parse the multipart form
|
||||
@@ -0,0 +1,27 @@
|
||||
# wol-proxy
|
||||
|
||||
wol-proxy automatically wakes up a suspended llama-swap server using Wake-on-LAN when requests are received.
|
||||
|
||||
When a request arrives and llama-swap is unavailable, wol-proxy sends a WOL packet and holds the request until the server becomes available. If the server doesn't respond within the timeout period (default: 60 seconds), the request is dropped.
|
||||
|
||||
This utility helps conserve energy by allowing GPU-heavy servers to remain suspended when idle, as they can consume hundreds of watts even when not actively processing requests.
|
||||
|
||||
## Usage
|
||||
|
||||
```shell
|
||||
# minimal
|
||||
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080
|
||||
|
||||
# everything
|
||||
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080 \
|
||||
# use debug log level
|
||||
-log debug \
|
||||
# altenerative listening port
|
||||
-listen localhost:9999 \
|
||||
# seconds to hold requests waiting for upstream to be ready
|
||||
-timeout 30
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
`GET /status` - that's it. Everything else is proxied to the upstream server.
|
||||
@@ -0,0 +1,64 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Loading...</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.loader {
|
||||
text-align: center;
|
||||
}
|
||||
.stats {
|
||||
font-size: 18px;
|
||||
color: #333;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.stats-label {
|
||||
color: #666;
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="loader">
|
||||
<p>Waking up upstream server...</p>
|
||||
<div class="stats">
|
||||
<div><span class="stats-label">Time elapsed:</span> <span id="elapsed">0s</span></div>
|
||||
<div><span id="attempts"> </span></div>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
var startTime = Date.now();
|
||||
var attempts = 0;
|
||||
|
||||
setInterval(function() {
|
||||
var elapsed = (Date.now() - startTime) / 1000;
|
||||
document.getElementById('elapsed').textContent = elapsed.toFixed(1) + 's';
|
||||
}, 100);
|
||||
|
||||
// Check status every second
|
||||
setInterval(function() {
|
||||
attempts++;
|
||||
var dots = '.'.repeat((attempts % 10) || 10);
|
||||
document.getElementById('attempts').textContent = dots;
|
||||
|
||||
fetch('/status')
|
||||
.then(function(r) { return r.text(); })
|
||||
.then(function(t) {
|
||||
if (t.indexOf('status: ready') !== -1) {
|
||||
location.reload();
|
||||
}
|
||||
})
|
||||
.catch(function() {});
|
||||
}, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,333 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:embed index.html
|
||||
var loadingPageHTML string
|
||||
|
||||
var (
|
||||
flagMac = flag.String("mac", "", "mac address to send WoL packet to")
|
||||
flagUpstream = flag.String("upstream", "", "upstream proxy address to send requests to")
|
||||
flagListen = flag.String("listen", ":8080", "listen address to listen on")
|
||||
flagLog = flag.String("log", "info", "log level (debug, info, warn, error)")
|
||||
flagTimeout = flag.Int("timeout", 60, "seconds requests wait for upstream response before failing")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
switch *flagLog {
|
||||
case "debug":
|
||||
slog.SetLogLoggerLevel(slog.LevelDebug)
|
||||
case "info":
|
||||
slog.SetLogLoggerLevel(slog.LevelInfo)
|
||||
case "warn":
|
||||
slog.SetLogLoggerLevel(slog.LevelWarn)
|
||||
case "error":
|
||||
slog.SetLogLoggerLevel(slog.LevelError)
|
||||
default:
|
||||
slog.Error("invalid log level", "logLevel", *flagLog)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate flags
|
||||
if *flagListen == "" {
|
||||
slog.Error("listen address is required")
|
||||
return
|
||||
}
|
||||
|
||||
if *flagMac == "" {
|
||||
slog.Error("mac address is required")
|
||||
return
|
||||
}
|
||||
|
||||
if *flagTimeout < 1 {
|
||||
slog.Error("timeout must be greater than 0")
|
||||
return
|
||||
}
|
||||
|
||||
var upstreamURL *url.URL
|
||||
var err error
|
||||
// validate mac address
|
||||
if _, err = net.ParseMAC(*flagMac); err != nil {
|
||||
slog.Error("invalid mac address", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if *flagUpstream == "" {
|
||||
slog.Error("upstream proxy address is required")
|
||||
return
|
||||
} else {
|
||||
upstreamURL, err = url.ParseRequestURI(*flagUpstream)
|
||||
if err != nil {
|
||||
slog.Error("error parsing upstream url", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxy := newProxy(upstreamURL)
|
||||
server := &http.Server{
|
||||
Addr: *flagListen,
|
||||
Handler: proxy,
|
||||
}
|
||||
|
||||
// start the server
|
||||
go func() {
|
||||
slog.Info("server starting on", "address", *flagListen)
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
slog.Error("error starting server", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// graceful shutdown
|
||||
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
<-ctx.Done()
|
||||
server.Close()
|
||||
}
|
||||
|
||||
type upstreamStatus string
|
||||
|
||||
const (
|
||||
notready upstreamStatus = "not ready"
|
||||
ready upstreamStatus = "ready"
|
||||
)
|
||||
|
||||
type proxyServer struct {
|
||||
upstreamProxy *httputil.ReverseProxy
|
||||
failCount int
|
||||
statusMutex sync.RWMutex
|
||||
status upstreamStatus
|
||||
}
|
||||
|
||||
func newProxy(url *url.URL) *proxyServer {
|
||||
p := httputil.NewSingleHostReverseProxy(url)
|
||||
proxy := &proxyServer{
|
||||
upstreamProxy: p,
|
||||
status: notready,
|
||||
failCount: 0,
|
||||
}
|
||||
|
||||
// start a goroutine to monitor upstream status via SSE
|
||||
go func() {
|
||||
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
|
||||
client := &http.Client{
|
||||
Timeout: 0, // No timeout for SSE connection
|
||||
}
|
||||
|
||||
waitDuration := 10 * time.Second
|
||||
|
||||
for {
|
||||
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
|
||||
|
||||
req, err := http.NewRequest("GET", eventsUrl, nil)
|
||||
if err != nil {
|
||||
slog.Warn("failed to create SSE request", "error", err)
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(waitDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
slog.Error("failed to connect to SSE endpoint", "error", err)
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// Successfully connected to SSE endpoint
|
||||
slog.Info("connected to SSE endpoint, upstream ready")
|
||||
proxy.setStatus(ready)
|
||||
proxy.resetFailures()
|
||||
|
||||
// Read from the SSE stream to detect disconnection
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
// use a fairly large buffer to avoid scanner errors when reading large SSE events
|
||||
buf := make([]byte, 0, 1024*1024*2)
|
||||
scanner.Buffer(buf, 1024*1024*2)
|
||||
events := 0
|
||||
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||
fmt.Print("Events: ")
|
||||
}
|
||||
for scanner.Scan() {
|
||||
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||
// Just read the events to keep connection alive
|
||||
// We don't need to process the event data
|
||||
events++
|
||||
fmt.Printf("%d, ", events)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
if err := scanner.Err(); err != nil {
|
||||
slog.Error("error reading from SSE stream", "error", err)
|
||||
}
|
||||
|
||||
// Connection closed or error occurred
|
||||
_ = resp.Body.Close()
|
||||
slog.Info("SSE connection closed, upstream not ready")
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
|
||||
// Wait before reconnecting
|
||||
time.Sleep(waitDuration)
|
||||
}
|
||||
}()
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "GET" && r.URL.Path == "/status" {
|
||||
status := string(p.getStatus())
|
||||
failCount := p.getFailures()
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprintf(w, "status: %s\n", status)
|
||||
fmt.Fprintf(w, "failures: %d\n", failCount)
|
||||
return
|
||||
}
|
||||
|
||||
if p.getStatus() == notready {
|
||||
path := r.URL.Path
|
||||
if strings.HasPrefix(path, "/api/events") {
|
||||
slog.Debug("Skipping wake up", "req", path)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
|
||||
if err := sendMagicPacket(*flagMac); err != nil {
|
||||
slog.Warn("failed to send magic WoL packet", "error", err)
|
||||
}
|
||||
|
||||
// For root or UI path requests, return loading page with status polling
|
||||
// the web page will do the polling and redirect when ready
|
||||
if path == "/" || strings.HasPrefix(path, "/ui/") {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, loadingPageHTML)
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
timeout, cancel := context.WithTimeout(context.Background(), time.Duration(*flagTimeout)*time.Second)
|
||||
defer cancel()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeout.Done():
|
||||
slog.Info("timeout waiting for upstream to be ready")
|
||||
http.Error(w, "timeout", http.StatusRequestTimeout)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if p.getStatus() == ready {
|
||||
ticker.Stop()
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.upstreamProxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (p *proxyServer) getStatus() upstreamStatus {
|
||||
p.statusMutex.RLock()
|
||||
defer p.statusMutex.RUnlock()
|
||||
return p.status
|
||||
}
|
||||
|
||||
func (p *proxyServer) setStatus(status upstreamStatus) {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.status = status
|
||||
}
|
||||
|
||||
func (p *proxyServer) incFail(num int) {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.failCount += num
|
||||
}
|
||||
|
||||
func (p *proxyServer) getFailures() int {
|
||||
p.statusMutex.RLock()
|
||||
defer p.statusMutex.RUnlock()
|
||||
return p.failCount
|
||||
}
|
||||
|
||||
func (p *proxyServer) resetFailures() {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.failCount = 0
|
||||
}
|
||||
|
||||
func sendMagicPacket(macAddr string) error {
|
||||
hwAddr, err := net.ParseMAC(macAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hwAddr) != 6 {
|
||||
return errors.New("invalid MAC address")
|
||||
}
|
||||
|
||||
// Create the magic packet.
|
||||
packet := make([]byte, 102)
|
||||
// Add 6 bytes of 0xFF.
|
||||
for i := 0; i < 6; i++ {
|
||||
packet[i] = 0xFF
|
||||
}
|
||||
// Repeat the MAC address 16 times.
|
||||
for i := 1; i <= 16; i++ {
|
||||
copy(packet[i*6:], hwAddr)
|
||||
}
|
||||
|
||||
// Send the packet using UDP.
|
||||
addr := net.UDPAddr{
|
||||
IP: net.IPv4bcast,
|
||||
Port: 9,
|
||||
}
|
||||
conn, err := net.DialUDP("udp", nil, &addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
{
|
||||
"$schema": "https://json-schema.org/draft-07/schema#",
|
||||
"$id": "llama-swap-config-schema.json",
|
||||
"title": "llama-swap configuration",
|
||||
"description": "Configuration file for llama-swap",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"models"
|
||||
],
|
||||
"definitions": {
|
||||
"macros": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"minLength": 0,
|
||||
"maxLength": 1024
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
}
|
||||
]
|
||||
},
|
||||
"propertyNames": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 64,
|
||||
"pattern": "^[a-zA-Z0-9_-]+$",
|
||||
"not": {
|
||||
"enum": [
|
||||
"PORT",
|
||||
"MODEL_ID"
|
||||
]
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"healthCheckTimeout": {
|
||||
"type": "integer",
|
||||
"minimum": 15,
|
||||
"default": 120,
|
||||
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
||||
},
|
||||
"logLevel": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"debug",
|
||||
"info",
|
||||
"warn",
|
||||
"error"
|
||||
],
|
||||
"default": "info",
|
||||
"description": "Sets the logging value. Valid values: debug, info, warn, error."
|
||||
},
|
||||
"logTimeFormat": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"",
|
||||
"ansic",
|
||||
"unixdate",
|
||||
"rubydate",
|
||||
"rfc822",
|
||||
"rfc822z",
|
||||
"rfc850",
|
||||
"rfc1123",
|
||||
"rfc1123z",
|
||||
"rfc3339",
|
||||
"rfc3339nano",
|
||||
"kitchen",
|
||||
"stamp",
|
||||
"stampmilli",
|
||||
"stampmicro",
|
||||
"stampnano"
|
||||
],
|
||||
"default": "",
|
||||
"description": "Enables and sets the logging timestamp format. Valid values: \"\", \"ansic\", \"unixdate\", \"rubydate\", \"rfc822\", \"rfc822z\", \"rfc850\", \"rfc1123\", \"rfc1123z\", \"rfc3339\", \"rfc3339nano\", \"kitchen\", \"stamp\", \"stampmilli\", \"stampmicro\", and \"stampnano\". For more info, read: https://pkg.go.dev/time#pkg-constants"
|
||||
},
|
||||
"metricsMaxInMemory": {
|
||||
"type": "integer",
|
||||
"default": 1000,
|
||||
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
||||
},
|
||||
"startPort": {
|
||||
"type": "integer",
|
||||
"default": 5800,
|
||||
"description": "Starting port number for the automatic ${PORT} macro. The ${PORT} macro is incremented for every model that uses it."
|
||||
},
|
||||
"sendLoadingState": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Inject loading status updates into the reasoning field. When true, a stream of loading messages will be sent to the client."
|
||||
},
|
||||
"includeAliasesInList": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Present aliases within the /v1/models OpenAI API listing. when true, model aliases will be output to the API model listing duplicating all fields except for Id so chat UIs can use the alias equivalent to the original."
|
||||
},
|
||||
"macros": {
|
||||
"$ref": "#/definitions/macros"
|
||||
},
|
||||
"models": {
|
||||
"type": "object",
|
||||
"description": "A dictionary of model configurations. Each key is a model's ID. Model settings have defaults if not defined. The model's ID is available as ${MODEL_ID}.",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"cmd"
|
||||
],
|
||||
"properties": {
|
||||
"macros": {
|
||||
"$ref": "#/definitions/macros"
|
||||
},
|
||||
"cmd": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "Command to run to start the inference server. Macros can be used. Comments allowed with |."
|
||||
},
|
||||
"cmdStop": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Command to run to stop the model gracefully. Uses ${PID} macro for upstream process id. If empty, default shutdown behavior is used."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"maxLength": 128,
|
||||
"description": "Display name for the model. Used in v1/models API response."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"maxLength": 1024,
|
||||
"description": "Description for the model. Used in v1/models API response."
|
||||
},
|
||||
"env": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "^[A-Z_][A-Z0-9_]*=.*$"
|
||||
},
|
||||
"default": [],
|
||||
"description": "Array of environment variables to inject into cmd's environment. Each value is a string in ENV_NAME=value format."
|
||||
},
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"default": "http://localhost:${PORT}",
|
||||
"format": "uri",
|
||||
"description": "URL where llama-swap routes API requests. If custom port is used in cmd, this must be set."
|
||||
},
|
||||
"aliases": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Alternative model names for this configuration. Must be unique globally."
|
||||
},
|
||||
"checkEndpoint": {
|
||||
"type": "string",
|
||||
"default": "/health",
|
||||
"pattern": "^/.*$|^none$",
|
||||
"description": "URL path to check if the server is ready. Use 'none' to skip health checking."
|
||||
},
|
||||
"ttl": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
|
||||
},
|
||||
"useModelName": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Override the model name sent to upstream server. Useful if upstream expects a different name."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings. Only stripParams is supported."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of arbitrary values included in /v1/models. Can contain complex types. Only passed through in /v1/models responses."
|
||||
},
|
||||
"concurrencyLimit": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Overrides allowed number of active parallel requests to a model. 0 uses internal default of 10. >0 overrides default. Requests exceeding limit get HTTP 429."
|
||||
},
|
||||
"sendLoadingState": {
|
||||
"type": "boolean",
|
||||
"description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting."
|
||||
},
|
||||
"unlisted": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"groups": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
},
|
||||
"hooks": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"on_startup": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"preload": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": [],
|
||||
"description": "List of model IDs to load on startup. Model names must match keys in models. When preloading multiple models, define a group to prevent swapping."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Actions to perform on startup. Only supported action is preload."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||
}
|
||||
}
|
||||
}
|
||||
+146
-32
@@ -1,9 +1,20 @@
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
# 💡 Tip - Use an LLM with this file!
|
||||
# ====================================
|
||||
# This example configuration is written to be LLM friendly. Try
|
||||
# copying this file into an LLM and asking it to explain or generate
|
||||
# sections for you.
|
||||
# ====================================
|
||||
|
||||
# Usage notes:
|
||||
# - Below are all the available configuration options for llama-swap.
|
||||
# - Settings with a default value, or noted as optional can be omitted.
|
||||
# - Settings that are marked required must be in your configuration file
|
||||
# - Settings noted as "required" must be in your configuration file
|
||||
# - Settings noted as "optional" can be omitted
|
||||
|
||||
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||
# - optional, default: 120
|
||||
@@ -15,6 +26,14 @@ healthCheckTimeout: 500
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
@@ -27,26 +46,59 @@ metricsMaxInMemory: 1000
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# macros: sets a dictionary of string:string pairs
|
||||
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||
# field
|
||||
# - optional, default: false
|
||||
# - when true, a stream of loading messages will be sent to the client in the
|
||||
# reasoning field so chat UIs can show that loading is in progress.
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - these are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy and checkEndpoint
|
||||
# - macros are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||
# - useful for reducing common configuration settings
|
||||
# - macro names are strings and must be less than 64 characters
|
||||
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
/path/to/llama-server/llama-server-ec9e0301
|
||||
--port ${PORT}
|
||||
|
||||
"default_ctx": 4096
|
||||
|
||||
# Example of macro-in-macro usage. macros can contain other macros
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
# - model settings have default values that are used if they are not defined here
|
||||
# - below are examples of the various settings a model can have:
|
||||
# - available model settings: env, cmd, cmdStop, proxy, aliases, checkEndpoint, ttl, unlisted
|
||||
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||
# - below are examples of the all the settings a model can have
|
||||
models:
|
||||
|
||||
# keys are the model names used in API requests
|
||||
"llama":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
# - model level macros follow the same rules as global macros
|
||||
macros:
|
||||
"default_ctx": 16384
|
||||
"temp": 0.7
|
||||
|
||||
# cmd: the command to run to start the inference server.
|
||||
# - required
|
||||
# - it is just a string, similar to what you would run on the CLI
|
||||
@@ -56,6 +108,8 @@ models:
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/llama-8B-Q4_K_M.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
@@ -92,49 +146,88 @@ models:
|
||||
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - use "none" to skip endpoint ready checking
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
# - all requests wait until the endpoint is ready (or fails)
|
||||
# - all requests wait until the endpoint is ready or fails
|
||||
# - use "none" to skip endpoint health checking
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# ttl: automatically unload the model after this many seconds
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: 0
|
||||
# - ttl values must be a value greater than 0
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
# useModelName: overrides the model name that is sent to upstream server
|
||||
# useModelName: override the model name that is sent to upstream server
|
||||
# - optional, default: ""
|
||||
# - useful when the upstream server expects a specific model name or format
|
||||
# - useful for when the upstream server expects a specific model name that
|
||||
# is different from the model's ID
|
||||
useModelName: "qwen:qwq"
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - only stripParams is currently supported
|
||||
filters:
|
||||
# strip_params: a comma separated list of parameters to remove from the request
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for preventing overriding of default server params by requests
|
||||
# - `model` parameter is never removed
|
||||
# - useful for server side enforcement of sampling parameters
|
||||
# - the `model` parameter can never be removed
|
||||
# - can be any JSON key in the request body
|
||||
# - recommended to stick to sampling parameters
|
||||
strip_params: "temperature, top_p, top_k"
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
# - metadata is only passed through in /v1/models responses
|
||||
metadata:
|
||||
# port will remain an integer
|
||||
port: ${PORT}
|
||||
|
||||
# the ${temp} macro will remain a float
|
||||
temperature: ${temp}
|
||||
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||
|
||||
a_list:
|
||||
- 1
|
||||
- 1.23
|
||||
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||
|
||||
an_obj:
|
||||
a: "1"
|
||||
b: 2
|
||||
# objects can contain complex types with macro substitution
|
||||
# becomes: c: [0.7, false, "model: llama"]
|
||||
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||
|
||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||
# - optional, default: 0
|
||||
# - useful for limiting the number of active parallel requests a model can process
|
||||
# - must be set per model
|
||||
# - any number greater than 0 will override the internal default value of 10
|
||||
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: true or false
|
||||
# unlisted: boolean, true or false
|
||||
# - optional, default: false
|
||||
# - unlisted models do not show up in /v1/models or /upstream lists
|
||||
# - unlisted models do not show up in /v1/models api requests
|
||||
# - can be requested as normal through all apis
|
||||
unlisted: true
|
||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker example:
|
||||
# container run times like Docker and Podman can also be used with a
|
||||
# a combination of cmd and cmdStop.
|
||||
# container runtimes like Docker and Podman can be used reliably with
|
||||
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
cmd: |
|
||||
docker run --name dockertest
|
||||
docker run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
@@ -142,24 +235,26 @@ models:
|
||||
# cmdStop: command to run to stop the model gracefully
|
||||
# - optional, default: ""
|
||||
# - useful for stopping commands managed by another system
|
||||
# - on POSIX systems: a SIGTERM is sent for graceful shutdown
|
||||
# - on Windows, taskkill is used
|
||||
# - processes are given 5 seconds to shutdown until they are forcefully killed
|
||||
# - the upstream's process id is available in the ${PID} macro
|
||||
cmdStop: docker stop dockertest
|
||||
#
|
||||
# When empty, llama-swap has this default behaviour:
|
||||
# - on POSIX systems: a SIGTERM signal is sent
|
||||
# - on Windows, calls taskkill to stop the process
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# groups: a dictionary of group settings
|
||||
# - optional, default: empty dictionary
|
||||
# - provide advanced controls over model swapping behaviour.
|
||||
# - Using groups some models can be kept loaded indefinitely, while others are swapped out.
|
||||
# - model ids must be defined in the Models section
|
||||
# - provides advanced controls over model swapping behaviour
|
||||
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
||||
# - model IDs must be defined in the Models section
|
||||
# - a model can only be a member of one group
|
||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||
groups:
|
||||
# group1 is same as the default behaviour of llama-swap where only one model is allowed
|
||||
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
||||
# to run a time across the whole llama-swap instance
|
||||
"group1":
|
||||
# swap: controls the model swapping behaviour in within the group
|
||||
@@ -181,10 +276,13 @@ groups:
|
||||
- "qwen-unlisted"
|
||||
|
||||
# Example:
|
||||
# - in this group all the models can run at the same time
|
||||
# - when a different group loads all running models in this group are unloaded
|
||||
# - in group2 all models can run at the same time
|
||||
# - when a different group is loaded it causes all running models in this group to unload
|
||||
"group2":
|
||||
swap: false
|
||||
|
||||
# exclusive: false does not unload other groups when a model in group2 is requested
|
||||
# - the models in group2 will be loaded but will not unload any other groups
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
@@ -207,3 +305,19 @@ groups:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
hooks:
|
||||
# on_startup: a dictionary of actions to perform on startup
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported action is preload
|
||||
on_startup:
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
+46
-22
@@ -20,36 +20,60 @@ if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
||||
# variable, this permits testing with forked llama.cpp repositories
|
||||
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
||||
|
||||
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
||||
# to enable easy container builds on forked repos
|
||||
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the latest available
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
||||
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
# cpu only containers just use the server tag
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r '.[] | select(.metadata.container.tags[] | startswith("server")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
BASE_TAG=server-${LCPP_TAG}
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||
fi
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
fi
|
||||
for CONTAINER_TYPE in non-root root; do
|
||||
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
|
||||
USER_UID=0
|
||||
USER_GID=0
|
||||
USER_HOME=/root
|
||||
|
||||
if [ "$CONTAINER_TYPE" == "non-root" ]; then
|
||||
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
|
||||
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
|
||||
USER_UID=10001
|
||||
USER_GID=10001
|
||||
USER_HOME=/app
|
||||
fi
|
||||
|
||||
echo "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
done
|
||||
|
||||
@@ -1,16 +1,40 @@
|
||||
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=89
|
||||
ARG LS_VER=170
|
||||
ARG LS_REPO=mostlygeek/llama-swap
|
||||
|
||||
# Set default UID/GID arguments
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
ARG USER_HOME=/app
|
||||
|
||||
# Add user/group
|
||||
ENV HOME=$USER_HOME
|
||||
RUN if [ $UID -ne 0 ]; then \
|
||||
if [ $GID -ne 0 ]; then \
|
||||
groupadd --system --gid $GID app; \
|
||||
fi; \
|
||||
useradd --system --uid $UID --gid $GID \
|
||||
--home $USER_HOME app; \
|
||||
fi
|
||||
|
||||
# Handle paths
|
||||
RUN mkdir --parents $HOME /app
|
||||
RUN chown --recursive $UID:$GID $HOME /app
|
||||
|
||||
# Switch user
|
||||
USER $UID:$GID
|
||||
|
||||
WORKDIR /app
|
||||
RUN \
|
||||
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
|
||||
|
||||
COPY config.example.yaml /app/config.yaml
|
||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
|
||||
@@ -0,0 +1,386 @@
|
||||
# config.yaml
|
||||
|
||||
llama-swap is designed to be very simple: one binary, one configuration file.
|
||||
|
||||
## minimal viable config
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
|
||||
```
|
||||
|
||||
This is enough to launch `llama-server` to serve `model1`. Of course, llama-swap is about making it possible to serve many models:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/model.gguf
|
||||
model2:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/another_model.gguf
|
||||
model3:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
|
||||
```
|
||||
|
||||
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model. Useful if you want to run multiple models at the same time with the `groups` feature.
|
||||
|
||||
## Advanced control with `cmd`
|
||||
|
||||
llama-swap is also about customizability. You can use any CLI flag available:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: | # support for multi-line
|
||||
llama-server --PORT ${PORT} -m /path/to/model.gguf
|
||||
--ctx-size 8192
|
||||
--jinja
|
||||
--cache-type-k q8_0
|
||||
--cache-type-v q8_0
|
||||
```
|
||||
|
||||
## Support for any OpenAI API compatible server
|
||||
|
||||
llama-swap supports any OpenAI API compatible server. If you can run it on the CLI llama-swap will be able to manage it. Even if it's run in Docker or Podman containers.
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"Q3-30B-CODER-VLLM":
|
||||
name: "Qwen3 30B Coder vllm AWQ (Q3-30B-CODER-VLLM)"
|
||||
# cmdStop provides a reliable way to stop containers
|
||||
cmdStop: docker stop vllm-coder
|
||||
cmd: |
|
||||
docker run --init --rm --name vllm-coder
|
||||
--runtime=nvidia --gpus '"device=2,3"'
|
||||
--shm-size=16g
|
||||
-v /mnt/nvme/vllm-cache:/root/.cache
|
||||
-v /mnt/ssd-extra/models:/models -p ${PORT}:8000
|
||||
vllm/vllm-openai:v0.10.0
|
||||
--model "/models/cpatonn/Qwen3-Coder-30B-A3B-Instruct-AWQ"
|
||||
--served-model-name "Q3-30B-CODER-VLLM"
|
||||
--enable-expert-parallel
|
||||
--swap-space 16
|
||||
--max-num-seqs 512
|
||||
--max-model-len 65536
|
||||
--max-seq-len-to-capture 65536
|
||||
--gpu-memory-utilization 0.9
|
||||
--tensor-parallel-size 2
|
||||
--trust-remote-code
|
||||
```
|
||||
|
||||
## Many more features..
|
||||
|
||||
llama-swap supports many more features to customize how you want to manage your environment.
|
||||
|
||||
| Feature | Description |
|
||||
| --------- | ---------------------------------------------- |
|
||||
| `ttl` | automatic unloading of models after a timeout |
|
||||
| `macros` | reusable snippets to use in configurations |
|
||||
| `groups` | run multiple models at a time |
|
||||
| `hooks` | event driven functionality |
|
||||
| `env` | define environment variables per model |
|
||||
| `aliases` | serve a model with different names |
|
||||
| `filters` | modify requests before sending to the upstream |
|
||||
| `...` | And many more tweaks |
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
> [!NOTE]
|
||||
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples.
|
||||
|
||||
```yaml
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
# 💡 Tip - Use an LLM with this file!
|
||||
# ====================================
|
||||
# This example configuration is written to be LLM friendly. Try
|
||||
# copying this file into an LLM and asking it to explain or generate
|
||||
# sections for you.
|
||||
# ====================================
|
||||
|
||||
# Usage notes:
|
||||
# - Below are all the available configuration options for llama-swap.
|
||||
# - Settings noted as "required" must be in your configuration file
|
||||
# - Settings noted as "optional" can be omitted
|
||||
|
||||
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||
# - optional, default: 120
|
||||
# - minimum value is 15 seconds, anything less will be set to this value
|
||||
healthCheckTimeout: 500
|
||||
|
||||
# logLevel: sets the logging value
|
||||
# - optional, default: info
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||
# - useful for reducing common configuration settings
|
||||
# - macro names are strings and must be less than 64 characters
|
||||
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
/path/to/llama-server/llama-server-ec9e0301
|
||||
--port ${PORT}
|
||||
|
||||
"default_ctx": 4096
|
||||
|
||||
# Example of macro-in-macro usage. macros can contain other macros
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
# - model settings have default values that are used if they are not defined here
|
||||
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||
# - below are examples of the all the settings a model can have
|
||||
models:
|
||||
# keys are the model names used in API requests
|
||||
"llama":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
# - model level macros follow the same rules as global macros
|
||||
macros:
|
||||
"default_ctx": 16384
|
||||
"temp": 0.7
|
||||
|
||||
# cmd: the command to run to start the inference server.
|
||||
# - required
|
||||
# - it is just a string, similar to what you would run on the CLI
|
||||
# - using `|` allows for comments in the command, these will be parsed out
|
||||
# - macros can be used within cmd
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/llama-8B-Q4_K_M.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "llama 3.1 8B"
|
||||
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A small but capable model used for quick testing"
|
||||
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
# - each value is a single string
|
||||
# - in the format: ENV_NAME=value
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||
|
||||
# proxy: the URL where llama-swap routes API requests
|
||||
# - optional, default: http://localhost:${PORT}
|
||||
# - if you used ${PORT} in cmd this can be omitted
|
||||
# - if you use a custom port in cmd this *must* be set
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
# - all requests wait until the endpoint is ready or fails
|
||||
# - use "none" to skip endpoint health checking
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: 0
|
||||
# - ttl values must be a value greater than 0
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
# useModelName: override the model name that is sent to upstream server
|
||||
# - optional, default: ""
|
||||
# - useful for when the upstream server expects a specific model name that
|
||||
# is different from the model's ID
|
||||
useModelName: "qwen:qwq"
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - only stripParams is currently supported
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for server side enforcement of sampling parameters
|
||||
# - the `model` parameter can never be removed
|
||||
# - can be any JSON key in the request body
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
# - metadata is only passed through in /v1/models responses
|
||||
metadata:
|
||||
# port will remain an integer
|
||||
port: ${PORT}
|
||||
|
||||
# the ${temp} macro will remain a float
|
||||
temperature: ${temp}
|
||||
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||
|
||||
a_list:
|
||||
- 1
|
||||
- 1.23
|
||||
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||
|
||||
an_obj:
|
||||
a: "1"
|
||||
b: 2
|
||||
# objects can contain complex types with macro substitution
|
||||
# becomes: c: [0.7, false, "model: llama"]
|
||||
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||
|
||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||
# - optional, default: 0
|
||||
# - useful for limiting the number of active parallel requests a model can process
|
||||
# - must be set per model
|
||||
# - any number greater than 0 will override the internal default value of 10
|
||||
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
# - optional, default: false
|
||||
# - unlisted models do not show up in /v1/models api requests
|
||||
# - can be requested as normal through all apis
|
||||
unlisted: true
|
||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker example:
|
||||
# container runtimes like Docker and Podman can be used reliably with
|
||||
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
cmd: |
|
||||
docker run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# cmdStop: command to run to stop the model gracefully
|
||||
# - optional, default: ""
|
||||
# - useful for stopping commands managed by another system
|
||||
# - the upstream's process id is available in the ${PID} macro
|
||||
#
|
||||
# When empty, llama-swap has this default behaviour:
|
||||
# - on POSIX systems: a SIGTERM signal is sent
|
||||
# - on Windows, calls taskkill to stop the process
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# groups: a dictionary of group settings
|
||||
# - optional, default: empty dictionary
|
||||
# - provides advanced controls over model swapping behaviour
|
||||
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
||||
# - model IDs must be defined in the Models section
|
||||
# - a model can only be a member of one group
|
||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||
groups:
|
||||
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
||||
# to run a time across the whole llama-swap instance
|
||||
"group1":
|
||||
# swap: controls the model swapping behaviour in within the group
|
||||
# - optional, default: true
|
||||
# - true : only one model is allowed to run at a time
|
||||
# - false: all models can run together, no swapping
|
||||
swap: true
|
||||
|
||||
# exclusive: controls how the group affects other groups
|
||||
# - optional, default: true
|
||||
# - true: causes all other groups to unload when this group runs a model
|
||||
# - false: does not affect other groups
|
||||
exclusive: true
|
||||
|
||||
# members references the models defined above
|
||||
# required
|
||||
members:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
|
||||
# Example:
|
||||
# - in group2 all models can run at the same time
|
||||
# - when a different group is loaded it causes all running models in this group to unload
|
||||
"group2":
|
||||
swap: false
|
||||
|
||||
# exclusive: false does not unload other groups when a model in group2 is requested
|
||||
# - the models in group2 will be loaded but will not unload any other groups
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
- "modelA"
|
||||
- "modelB"
|
||||
|
||||
# Example:
|
||||
# - a persistent group, prevents other groups from unloading it
|
||||
"forever":
|
||||
# persistent: prevents over groups from unloading the models in this group
|
||||
# - optional, default: false
|
||||
# - does not affect individual model behaviour
|
||||
persistent: true
|
||||
|
||||
# set swap/exclusive to false to prevent swapping inside the group
|
||||
# and the unloading of other groups
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
hooks:
|
||||
# on_startup: a dictionary of actions to perform on startup
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported action is preload
|
||||
on_startup:
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
```
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/mostlygeek/llama-swap
|
||||
|
||||
go 1.23.0
|
||||
go 1.25.4
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
@@ -37,9 +37,9 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -80,16 +80,16 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
+45
-9
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -27,7 +28,9 @@ var (
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||
listenStr := flag.String("listen", "", "listen ip/port")
|
||||
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
|
||||
@@ -38,13 +41,13 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
config, err := proxy.LoadConfig(*configPath)
|
||||
conf, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(config.Profiles) > 0 {
|
||||
if len(conf.Profiles) > 0 {
|
||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
@@ -54,6 +57,23 @@ func main() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
// Validate TLS flags.
|
||||
var useTLS = (*certFile != "" && *keyFile != "")
|
||||
if (*certFile != "" && *keyFile == "") ||
|
||||
(*certFile == "" && *keyFile != "") {
|
||||
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default ports.
|
||||
if *listenStr == "" {
|
||||
defaultPort := ":8080"
|
||||
if useTLS {
|
||||
defaultPort = ":8443"
|
||||
}
|
||||
listenStr = &defaultPort
|
||||
}
|
||||
|
||||
// Setup channels for server management
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
@@ -67,7 +87,7 @@ func main() {
|
||||
// Support for watching config and reloading when it changes
|
||||
reloadProxyManager := func() {
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
config, err = proxy.LoadConfig(*configPath)
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||
return
|
||||
@@ -75,7 +95,9 @@ func main() {
|
||||
|
||||
fmt.Println("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
srv.Handler = proxy.New(config)
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
srv.Handler = newPM
|
||||
fmt.Println("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
@@ -85,12 +107,14 @@ func main() {
|
||||
})
|
||||
})
|
||||
} else {
|
||||
config, err = proxy.LoadConfig(*configPath)
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srv.Handler = proxy.New(config)
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
srv.Handler = newPM
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +156,11 @@ func main() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
||||
// the change for k8s configmap
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
}
|
||||
|
||||
case err := <-watcher.Errors:
|
||||
@@ -161,9 +190,16 @@ func main() {
|
||||
}()
|
||||
|
||||
// Start server
|
||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
var err error
|
||||
if useTLS {
|
||||
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
|
||||
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||
} else {
|
||||
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Fatal server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 51 KiB |
-427
@@ -1,427 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
// Model filters see issue #174
|
||||
Filters ModelFilters `yaml:"filters"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelConfig ModelConfig
|
||||
defaults := rawModelConfig{
|
||||
Cmd: "",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:${PORT}",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" {
|
||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"strip_params"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelFilters(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" {
|
||||
continue
|
||||
}
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros map[string]string `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// default configuration values
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
for macroName, macroValue := range config.Macros {
|
||||
if len(macroName) >= 64 {
|
||||
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
|
||||
}
|
||||
if !macroNameRegex.MatchString(macroName) {
|
||||
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
|
||||
}
|
||||
if len(macroValue) >= 1024 {
|
||||
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
|
||||
}
|
||||
switch macroName {
|
||||
case "PORT":
|
||||
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
||||
for macroName, macroValue := range config.Macros {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
|
||||
}
|
||||
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||
nextPortStr := strconv.Itoa(nextPort)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
}
|
||||
if _, exists := config.Macros[macroName]; !exists {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName, _ := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
@@ -0,0 +1,616 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
type MacroList []MacroEntry
|
||||
|
||||
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||
if value.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("macros must be a mapping")
|
||||
}
|
||||
|
||||
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||
for i := 0; i < len(value.Content); i += 2 {
|
||||
keyNode := value.Content[i]
|
||||
valueNode := value.Content[i+1]
|
||||
|
||||
var name string
|
||||
if err := keyNode.Decode(&name); err != nil {
|
||||
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||
}
|
||||
|
||||
var val any
|
||||
if err := valueNode.Decode(&val); err != nil {
|
||||
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||
}
|
||||
|
||||
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||
}
|
||||
|
||||
*ml = entries
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a macro value by name
|
||||
func (ml MacroList) Get(name string) (any, bool) {
|
||||
for _, entry := range ml {
|
||||
if entry.Name == name {
|
||||
return entry.Value, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||
func (ml MacroList) ToMap() map[string]any {
|
||||
result := make(map[string]any, len(ml))
|
||||
for _, entry := range ml {
|
||||
result[entry.Name] = entry.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||
}
|
||||
|
||||
type HookOnStartup struct {
|
||||
Preload []string `yaml:"preload"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
|
||||
// hooks, see: #209
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
|
||||
// send loading state in reasoning
|
||||
SendLoadingState bool `yaml:"sendLoadingState"`
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// default configuration values
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT, MODEL_ID
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Add global macros first
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (can override global)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Final pass: check if PORT macro is needed after macro expansion
|
||||
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||
// if it is required in either cmd or proxy keys
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort { // either has it
|
||||
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro and substitute it
|
||||
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
|
||||
// Substitute PORT in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
}
|
||||
// Reserved macros are always valid (they should have been substituted already)
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
// Any other macro is unknown
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unknown macros in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the proxy URL.
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf(
|
||||
"model %s: invalid proxy URL: %w", modelId, err,
|
||||
)
|
||||
}
|
||||
|
||||
// if sendLoadingState is nil, set it to the global config value
|
||||
// see #366
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState // copy it
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
|
||||
// clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if len(v) >= 1024 {
|
||||
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||
}
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -58,6 +58,7 @@ models:
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
@@ -100,6 +101,9 @@ func TestConfig_LoadPosix(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
hooks:
|
||||
on_startup:
|
||||
preload: ["model1", "model2"]
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
@@ -157,42 +161,55 @@ groups:
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
StartPort: 5800,
|
||||
Macros: map[string]string{
|
||||
"svr-path": "path/to/server",
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
Hooks: HooksConfig{
|
||||
OnStartup: HookOnStartup{
|
||||
Preload: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -1,4 +1,4 @@
|
||||
package proxy
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -65,18 +65,6 @@ models:
|
||||
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||
}
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_FindConfig(t *testing.T) {
|
||||
|
||||
// TODO?
|
||||
@@ -207,30 +195,91 @@ macros:
|
||||
argOne: "--arg1"
|
||||
argTwo: "--arg2"
|
||||
autoPort: "--port ${PORT}"
|
||||
overriddenByModelMacro: failed
|
||||
|
||||
models:
|
||||
model1:
|
||||
macros:
|
||||
overriddenByModelMacro: success
|
||||
cmd: |
|
||||
${svr-path} ${argTwo}
|
||||
# the automatic ${PORT} is replaced
|
||||
${autoPort}
|
||||
${argOne}
|
||||
--arg3 three
|
||||
--overridden ${overriddenByModelMacro}
|
||||
cmdStop: |
|
||||
/path/to/stop.sh --port ${PORT} ${argTwo}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
|
||||
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
|
||||
|
||||
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
|
||||
}
|
||||
|
||||
func TestConfig_MacroReservedNames(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "global macro named PORT",
|
||||
config: `
|
||||
macros:
|
||||
PORT: "1111"
|
||||
`,
|
||||
expectedError: "macro name 'PORT' is reserved",
|
||||
},
|
||||
{
|
||||
name: "global macro named MODEL_ID",
|
||||
config: `
|
||||
macros:
|
||||
MODEL_ID: model1
|
||||
`,
|
||||
expectedError: "macro name 'MODEL_ID' is reserved",
|
||||
},
|
||||
{
|
||||
name: "model macro named PORT",
|
||||
config: `
|
||||
models:
|
||||
model1:
|
||||
macros:
|
||||
PORT: 1111
|
||||
`,
|
||||
expectedError: "model model1: macro name 'PORT' is reserved",
|
||||
},
|
||||
|
||||
{
|
||||
name: "model macro named MODEL_ID",
|
||||
config: `
|
||||
models:
|
||||
model1:
|
||||
macros:
|
||||
MODEL_ID: model1
|
||||
`,
|
||||
expectedError: "model model1: macro name 'MODEL_ID' is reserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.config))
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, tt.expectedError, err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -274,7 +323,7 @@ macros:
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
proxy: "http://localhost:${unknownMacro}"
|
||||
proxy: "http://${unknownMacro}:${PORT}"
|
||||
`,
|
||||
},
|
||||
{
|
||||
@@ -288,6 +337,20 @@ models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
checkEndpoint: "http://localhost:${unknownMacro}/health"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "unknown macro in filters.stripParams",
|
||||
field: "filters.stripParams",
|
||||
content: `
|
||||
startPort: 9990
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: "${svr-path} --port ${PORT}"
|
||||
filters:
|
||||
stripParams: "model,${unknownMacro}"
|
||||
`,
|
||||
},
|
||||
}
|
||||
@@ -295,38 +358,13 @@ models:
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||
}
|
||||
//t.Log(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelFilters(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
default_strip: "temperature, top_p"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
strip_params: "model, top_k, ${default_strip}, , ,"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
modelConfig, ok := config.Models["model1"]
|
||||
if !assert.True(t, ok) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// make sure `model` and enmpty strings are not in the list
|
||||
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripComments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -440,3 +478,286 @@ models:
|
||||
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
|
||||
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
|
||||
}
|
||||
|
||||
func TestConfig_MacroModelId(t *testing.T) {
|
||||
content := `
|
||||
startPort: 9000
|
||||
macros:
|
||||
"docker-llama": docker run --name ${MODEL_ID} -p ${PORT}:8080 docker_img
|
||||
"docker-stop": docker stop ${MODEL_ID}
|
||||
|
||||
models:
|
||||
model1:
|
||||
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
|
||||
|
||||
model2:
|
||||
cmd: ${docker-llama}
|
||||
cmdStop: ${docker-stop}
|
||||
|
||||
author/model:F16:
|
||||
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
|
||||
cmdStop: stop
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
|
||||
|
||||
dockerStopMacro, found := config.Macros.Get("docker-stop")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "docker stop ${MODEL_ID}", dockerStopMacro)
|
||||
|
||||
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "docker run --name model2 -p 9002:8080 docker_img", strings.Join(sanitizedCmd2, " "))
|
||||
|
||||
sanitizedCmdStop, err := SanitizeCommand(config.Models["model2"].CmdStop)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "docker stop model2", strings.Join(sanitizedCmdStop, " "))
|
||||
|
||||
sanitizedCmd3, err := SanitizeCommand(config.Models["author/model:F16"].Cmd)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
|
||||
}
|
||||
|
||||
func TestConfig_TypedMacrosInMetadata(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
PORT_NUM: 10001
|
||||
TEMP: 0.7
|
||||
ENABLED: true
|
||||
NAME: "llama model"
|
||||
CTX: 16384
|
||||
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
metadata:
|
||||
port: ${PORT_NUM}
|
||||
temperature: ${TEMP}
|
||||
enabled: ${ENABLED}
|
||||
model_name: ${NAME}
|
||||
context: ${CTX}
|
||||
note: "Running on port ${PORT_NUM} with temp ${TEMP} and context ${CTX}"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
meta := config.Models["test-model"].Metadata
|
||||
assert.NotNil(t, meta)
|
||||
|
||||
// Verify direct substitution preserves types
|
||||
assert.Equal(t, 10001, meta["port"])
|
||||
assert.Equal(t, 0.7, meta["temperature"])
|
||||
assert.Equal(t, true, meta["enabled"])
|
||||
assert.Equal(t, "llama model", meta["model_name"])
|
||||
assert.Equal(t, 16384, meta["context"])
|
||||
|
||||
// Verify string interpolation converts to string
|
||||
assert.Equal(t, "Running on port 10001 with temp 0.7 and context 16384", meta["note"])
|
||||
}
|
||||
|
||||
func TestConfig_NestedStructuresInMetadata(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
TEMP: 0.7
|
||||
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
metadata:
|
||||
config:
|
||||
port: ${PORT}
|
||||
temperature: ${TEMP}
|
||||
tags: ["model:${MODEL_ID}", "port:${PORT}"]
|
||||
nested:
|
||||
deep:
|
||||
value: ${TEMP}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
meta := config.Models["test-model"].Metadata
|
||||
assert.NotNil(t, meta)
|
||||
|
||||
// Verify nested objects
|
||||
configMap := meta["config"].(map[string]any)
|
||||
assert.Equal(t, 10000, configMap["port"])
|
||||
assert.Equal(t, 0.7, configMap["temperature"])
|
||||
|
||||
// Verify arrays
|
||||
tags := meta["tags"].([]any)
|
||||
assert.Equal(t, "model:test-model", tags[0])
|
||||
assert.Equal(t, "port:10000", tags[1])
|
||||
|
||||
// Verify deeply nested structures
|
||||
nested := meta["nested"].(map[string]any)
|
||||
deep := nested["deep"].(map[string]any)
|
||||
assert.Equal(t, 0.7, deep["value"])
|
||||
}
|
||||
|
||||
func TestConfig_ModelLevelMacroPrecedenceInMetadata(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
TEMP: 0.5
|
||||
GLOBAL_VAL: "global"
|
||||
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
macros:
|
||||
TEMP: 0.9
|
||||
LOCAL_VAL: "local"
|
||||
metadata:
|
||||
temperature: ${TEMP}
|
||||
global: ${GLOBAL_VAL}
|
||||
local: ${LOCAL_VAL}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
meta := config.Models["test-model"].Metadata
|
||||
assert.NotNil(t, meta)
|
||||
|
||||
// Model-level macro should override global
|
||||
assert.Equal(t, 0.9, meta["temperature"])
|
||||
// Global macro should be accessible
|
||||
assert.Equal(t, "global", meta["global"])
|
||||
// Model-level macro should be accessible
|
||||
assert.Equal(t, "local", meta["local"])
|
||||
}
|
||||
|
||||
func TestConfig_UnknownMacroInMetadata(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
metadata:
|
||||
value: ${UNKNOWN_MACRO}
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "test-model")
|
||||
assert.Contains(t, err.Error(), "UNKNOWN_MACRO")
|
||||
}
|
||||
|
||||
func TestConfig_InvalidMacroType(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
INVALID:
|
||||
nested: value
|
||||
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "INVALID")
|
||||
assert.Contains(t, err.Error(), "must be a scalar type")
|
||||
}
|
||||
|
||||
func TestConfig_MacroTypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "string macro",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
STR: "test"
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "int macro",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
NUM: 42
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "float macro",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
FLOAT: 3.14
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "bool macro",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
BOOL: true
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "array macro (invalid)",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
ARR: [1, 2, 3]
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "map macro (invalid)",
|
||||
yaml: `
|
||||
startPort: 10000
|
||||
macros:
|
||||
MAP:
|
||||
key: value
|
||||
models:
|
||||
test-model:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`,
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.yaml))
|
||||
if tt.shouldErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -55,6 +55,7 @@ models:
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
@@ -152,44 +153,52 @@ groups:
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
StartPort: 5800,
|
||||
Macros: map[string]string{
|
||||
"svr-path": "path/to/server",
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -0,0 +1,123 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test macro-in-macro basic substitution
|
||||
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${B}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test LIFO substitution order with 3+ macro levels
|
||||
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: load ${full}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test MODEL_ID in global macro used by model
|
||||
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: ${podman-llama} -m model.gguf
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||
}
|
||||
|
||||
// Test model macro overrides global macro in substitution
|
||||
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: echo ${msg}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test self-reference detection error
|
||||
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${recursive}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "recursive")
|
||||
assert.Contains(t, err.Error(), "self-reference")
|
||||
}
|
||||
|
||||
// Test undefined macro reference error
|
||||
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${A}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
// Model filters see issue #174
|
||||
Filters ModelFilters `yaml:"filters"`
|
||||
|
||||
// Macros: see #264
|
||||
// Model level macros take precedence over the global macros
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// Metadata: see #264
|
||||
// Arbitrary metadata that can be exposed through the API
|
||||
Metadata map[string]any `yaml:"metadata"`
|
||||
|
||||
// override global setting
|
||||
SendLoadingState *bool `yaml:"sendLoadingState"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelConfig ModelConfig
|
||||
defaults := rawModelConfig{
|
||||
Cmd: "",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:${PORT}",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" {
|
||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"stripParams"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal with the old field name for backwards compatibility
|
||||
if defaults.StripParams == "" {
|
||||
var legacy struct {
|
||||
StripParams string `yaml:"strip_params"`
|
||||
}
|
||||
if legacyErr := unmarshal(&legacy); legacyErr != nil {
|
||||
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
|
||||
}
|
||||
defaults.StripParams = legacy.StripParams
|
||||
}
|
||||
|
||||
*m = ModelFilters(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_ModelFilters(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
default_strip: "temperature, top_p"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
# macros inserted and list is cleaned of duplicates and empty strings
|
||||
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
# check for strip_params (legacy field name) compatibility
|
||||
legacy:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
for modelId, modelConfig := range config.Models {
|
||||
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
|
||||
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||
if assert.NoError(t, err) {
|
||||
// model has been removed
|
||||
// empty strings have been removed
|
||||
// duplicates have been removed
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelSendLoadingState(t *testing.T) {
|
||||
content := `
|
||||
sendLoadingState: true
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
sendLoadingState: false
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, config.SendLoadingState)
|
||||
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
|
||||
assert.False(t, *config.Models["model1"].SendLoadingState)
|
||||
}
|
||||
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
|
||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package proxy
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||
type DiscardWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
}
|
||||
|
||||
// Satisfy the http.Flusher interface for streaming responses
|
||||
func (w *DiscardWriter) Flush() {}
|
||||
@@ -7,6 +7,7 @@ const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
@@ -48,3 +49,12 @@ type LogDataEvent struct {
|
||||
func (e LogDataEvent) Type() uint32 {
|
||||
return LogDataEventID
|
||||
}
|
||||
|
||||
type ModelPreloadedEvent struct {
|
||||
ModelName string
|
||||
Success bool
|
||||
}
|
||||
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
@@ -9,13 +9,15 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
simpleResponderPath = getSimpleResponderPath()
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
@@ -64,20 +66,18 @@ func getTestPort() int {
|
||||
return port
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, binaryPath, port, expectedMessage, port)
|
||||
`, simpleResponderPath, port, expectedMessage, port)
|
||||
|
||||
var cfg ModelConfig
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||
}
|
||||
|
||||
+21
-6
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
@@ -32,6 +33,9 @@ type LogMonitor struct {
|
||||
// logging levels
|
||||
level LogLevel
|
||||
prefix string
|
||||
|
||||
// timestamps
|
||||
timeFormat string
|
||||
}
|
||||
|
||||
func NewLogMonitor() *LogMonitor {
|
||||
@@ -40,11 +44,12 @@ func NewLogMonitor() *LogMonitor {
|
||||
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
timeFormat: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,12 +111,22 @@ func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
||||
w.level = level
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.timeFormat = timeFormat
|
||||
}
|
||||
|
||||
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||
prefix := ""
|
||||
if w.prefix != "" {
|
||||
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
|
||||
timestamp := ""
|
||||
if w.timeFormat != "" {
|
||||
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||
|
||||
@@ -3,8 +3,10 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLogMonitor(t *testing.T) {
|
||||
@@ -84,3 +86,30 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
|
||||
t.Errorf("Expected history to be %q, got %q", expected, history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite_LogTimeFormat(t *testing.T) {
|
||||
// Create a new LogMonitor instance
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Enable timestamps
|
||||
lm.timeFormat = time.RFC3339
|
||||
|
||||
// Write the message to the LogMonitor
|
||||
lm.Info("Hello, World!")
|
||||
|
||||
// Get the history from the LogMonitor
|
||||
history := lm.GetHistory()
|
||||
|
||||
timestamp := ""
|
||||
fields := strings.Fields(string(history))
|
||||
if len(fields) > 0 {
|
||||
timestamp = fields[0]
|
||||
} else {
|
||||
t.Fatalf("Cannot extract string from history")
|
||||
}
|
||||
|
||||
_, err := time.Parse(time.RFC3339, timestamp)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot find timestamp: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
||||
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
c.Set("ls-real-model-name", realModelName)
|
||||
|
||||
writer := &MetricsResponseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
metricsRecorder: &MetricsRecorder{
|
||||
metricsMonitor: pm.metricsMonitor,
|
||||
realModelName: realModelName,
|
||||
isStreaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||
startTime: time.Now(),
|
||||
},
|
||||
}
|
||||
c.Writer = writer
|
||||
c.Next()
|
||||
|
||||
rec := writer.metricsRecorder
|
||||
rec.processBody(writer.body)
|
||||
}
|
||||
}
|
||||
|
||||
type MetricsRecorder struct {
|
||||
metricsMonitor *MetricsMonitor
|
||||
realModelName string
|
||||
isStreaming bool
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// processBody handles response processing after request completes
|
||||
func (rec *MetricsRecorder) processBody(body []byte) {
|
||||
if rec.isStreaming {
|
||||
rec.processStreamingResponse(body)
|
||||
} else {
|
||||
rec.processNonStreamingResponse(body)
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
|
||||
usage := jsonData.Get("usage")
|
||||
if !usage.Exists() {
|
||||
return false
|
||||
}
|
||||
|
||||
// default values
|
||||
outputTokens := int(jsonData.Get("usage.completion_tokens").Int())
|
||||
inputTokens := int(jsonData.Get("usage.prompt_tokens").Int())
|
||||
tokensPerSecond := -1.0
|
||||
durationMs := int(time.Since(rec.startTime).Milliseconds())
|
||||
|
||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||
if timings := jsonData.Get("timings"); timings.Exists() {
|
||||
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||
}
|
||||
|
||||
rec.metricsMonitor.addMetrics(TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: rec.realModelName,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
DurationMs: durationMs,
|
||||
})
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
||||
// Iterate **backwards** through the lines looking for the data payload with
|
||||
// usage data
|
||||
lines := bytes.Split(body, []byte("\n"))
|
||||
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
line := bytes.TrimSpace(lines[i])
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// SSE payload always follows "data:"
|
||||
prefix := []byte("data:")
|
||||
if !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
// [DONE] line itself contains nothing of interest.
|
||||
continue
|
||||
}
|
||||
|
||||
if gjson.ValidBytes(data) {
|
||||
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
|
||||
return // short circuit if a metric was recorded
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
||||
if len(body) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON to extract usage information
|
||||
if gjson.ValidBytes(body) {
|
||||
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsResponseWriter captures the entire response for non-streaming
|
||||
type MetricsResponseWriter struct {
|
||||
gin.ResponseWriter
|
||||
body []byte
|
||||
metricsRecorder *MetricsRecorder
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
w.body = append(w.body, b...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
+224
-15
@@ -1,11 +1,18 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||
@@ -13,8 +20,10 @@ type TokenMetrics struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
}
|
||||
@@ -28,21 +37,18 @@ func (e TokenMetricsEvent) Type() uint32 {
|
||||
return TokenMetricsEventID // defined in events.go
|
||||
}
|
||||
|
||||
// MetricsMonitor parses llama-server output for token statistics
|
||||
type MetricsMonitor struct {
|
||||
// metricsMonitor parses llama-server output for token statistics
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics []TokenMetrics
|
||||
maxMetrics int
|
||||
nextID int
|
||||
logger *LogMonitor
|
||||
}
|
||||
|
||||
func NewMetricsMonitor(config *Config) *MetricsMonitor {
|
||||
maxMetrics := config.MetricsMaxInMemory
|
||||
if maxMetrics <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
}
|
||||
|
||||
mp := &MetricsMonitor{
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
|
||||
mp := &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
}
|
||||
|
||||
@@ -50,7 +56,7 @@ func NewMetricsMonitor(config *Config) *MetricsMonitor {
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event
|
||||
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
@@ -60,12 +66,11 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
if len(mp.metrics) > mp.maxMetrics {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the current metrics
|
||||
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
||||
// getMetrics returns a copy of the current metrics
|
||||
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
@@ -74,9 +79,213 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetMetricsJSON returns metrics as JSON
|
||||
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
|
||||
// getMetricsJSON returns metrics as JSON
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
return json.Marshal(mp.metrics)
|
||||
}
|
||||
|
||||
// wrapHandler wraps the proxy handler to extract token metrics
|
||||
// if wrapHandler returns an error it is safe to assume that no
|
||||
// data was sent to the client
|
||||
func (mp *metricsMonitor) wrapHandler(
|
||||
modelID string,
|
||||
writer gin.ResponseWriter,
|
||||
request *http.Request,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
recorder := newBodyCopier(writer)
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// after this point we have to assume that data was sent to the client
|
||||
// and we can only log errors but not send them to clients
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||
return nil
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics skipped, empty body")
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
mp.addMetrics(tm)
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if tm, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
mp.addMetrics(tm)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
|
||||
// Iterate **backwards** through the body looking for the data payload with
|
||||
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
||||
|
||||
// Start from the end of the body and scan backwards for newlines
|
||||
pos := len(body)
|
||||
for pos > 0 {
|
||||
// Find the previous newline (or start of body)
|
||||
lineStart := bytes.LastIndexByte(body[:pos], '\n')
|
||||
if lineStart == -1 {
|
||||
lineStart = 0
|
||||
} else {
|
||||
lineStart++ // Move past the newline
|
||||
}
|
||||
|
||||
line := bytes.TrimSpace(body[lineStart:pos])
|
||||
pos = lineStart - 1 // Move position before the newline for next iteration
|
||||
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// SSE payload always follows "data:"
|
||||
prefix := []byte("data:")
|
||||
if !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
// [DONE] line itself contains nothing of interest.
|
||||
continue
|
||||
}
|
||||
|
||||
if gjson.ValidBytes(data) {
|
||||
parsed := gjson.ParseBytes(data)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
return parseMetrics(modelID, start, usage, timings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
||||
// default values
|
||||
cachedTokens := -1 // unknown or missing data
|
||||
outputTokens := 0
|
||||
inputTokens := 0
|
||||
|
||||
// timings data
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
durationMs := int(time.Since(start).Milliseconds())
|
||||
|
||||
if usage.Exists() {
|
||||
if pt := usage.Get("prompt_tokens"); pt.Exists() {
|
||||
// v1/chat/completions
|
||||
inputTokens = int(pt.Int())
|
||||
} else if it := usage.Get("input_tokens"); it.Exists() {
|
||||
// v1/messages
|
||||
inputTokens = int(it.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("completion_tokens"); ct.Exists() {
|
||||
// v1/chat/completions
|
||||
outputTokens = int(ct.Int())
|
||||
} else if ot := usage.Get("output_tokens"); ot.Exists() {
|
||||
outputTokens = int(ot.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
|
||||
cachedTokens = int(ct.Int())
|
||||
}
|
||||
}
|
||||
|
||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||
if timings.Exists() {
|
||||
inputTokens = int(timings.Get("prompt_n").Int())
|
||||
outputTokens = int(timings.Get("predicted_n").Int())
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
durationMs = int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = int(cachedValue.Int())
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
CachedTokens: cachedTokens,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
DurationMs: durationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
tee io.Writer
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||
bodyBuffer := &bytes.Buffer{}
|
||||
return &responseBodyCopier{
|
||||
ResponseWriter: w,
|
||||
body: bodyBuffer,
|
||||
tee: io.MultiWriter(w, bodyBuffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
if w.start.IsZero() {
|
||||
w.start = time.Now()
|
||||
}
|
||||
|
||||
// Single write operation that writes to both the response and buffer
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
@@ -0,0 +1,693 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
|
||||
mm.addMetrics(metric)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 0, metrics[0].ID)
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
mm.addMetrics(TokenMetrics{Model: "model"})
|
||||
}
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 5, len(metrics))
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Equal(t, i, metrics[i].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("respects max metrics limit", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 3)
|
||||
|
||||
// Add 5 metrics
|
||||
for i := 0; i < 5; i++ {
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model",
|
||||
InputTokens: i,
|
||||
})
|
||||
}
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 3, len(metrics))
|
||||
|
||||
// Should keep the last 3 metrics (IDs 2, 3, 4)
|
||||
assert.Equal(t, 2, metrics[0].ID)
|
||||
assert.Equal(t, 3, metrics[1].ID)
|
||||
assert.Equal(t, 4, metrics[2].ID)
|
||||
})
|
||||
|
||||
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
receivedEvent := make(chan TokenMetricsEvent, 1)
|
||||
cancel := event.On(func(e TokenMetricsEvent) {
|
||||
receivedEvent <- e
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
|
||||
mm.addMetrics(metric)
|
||||
|
||||
select {
|
||||
case evt := <-receivedEvent:
|
||||
assert.Equal(t, 0, evt.Metrics.ID)
|
||||
assert.Equal(t, "test-model", evt.Metrics.Model)
|
||||
assert.Equal(t, 100, evt.Metrics.InputTokens)
|
||||
assert.Equal(t, 50, evt.Metrics.OutputTokens)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("timeout waiting for event")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||
t.Run("returns empty slice when no metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
metrics := mm.getMetrics()
|
||||
assert.NotNil(t, metrics)
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm.addMetrics(TokenMetrics{Model: "model1"})
|
||||
mm.addMetrics(TokenMetrics{Model: "model2"})
|
||||
|
||||
metrics1 := mm.getMetrics()
|
||||
metrics2 := mm.getMetrics()
|
||||
|
||||
// Verify we got copies
|
||||
assert.Equal(t, 2, len(metrics1))
|
||||
assert.Equal(t, 2, len(metrics2))
|
||||
|
||||
// Modify the returned slice shouldn't affect the original
|
||||
metrics1[0].Model = "modified"
|
||||
metrics3 := mm.getMetrics()
|
||||
assert.Equal(t, "model1", metrics3[0].Model)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
jsonData, err := mm.getMetricsJSON()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jsonData)
|
||||
|
||||
var metrics []TokenMetrics
|
||||
err = json.Unmarshal(jsonData, &metrics)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model1",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TokensPerSecond: 25.5,
|
||||
})
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model2",
|
||||
InputTokens: 200,
|
||||
OutputTokens: 100,
|
||||
TokensPerSecond: 30.0,
|
||||
})
|
||||
|
||||
jsonData, err := mm.getMetricsJSON()
|
||||
assert.NoError(t, err)
|
||||
|
||||
var metrics []TokenMetrics
|
||||
err = json.Unmarshal(jsonData, &metrics)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(metrics))
|
||||
assert.Equal(t, "model1", metrics[0].Model)
|
||||
assert.Equal(t, "model2", metrics[1].Model)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `{
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50
|
||||
}
|
||||
}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("successful request with timings data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `{
|
||||
"timings": {
|
||||
"prompt_n": 100,
|
||||
"predicted_n": 50,
|
||||
"prompt_per_second": 150.5,
|
||||
"predicted_per_second": 25.5,
|
||||
"prompt_ms": 500.0,
|
||||
"predicted_ms": 1500.0,
|
||||
"cache_n": 20
|
||||
}
|
||||
}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 20, metrics[0].CachedTokens)
|
||||
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
|
||||
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
|
||||
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
|
||||
})
|
||||
|
||||
t.Run("streaming request with SSE format", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
// Note: SSE format requires proper line breaks - each data line followed by blank line
|
||||
responseBody := `data: {"choices":[{"text":"Hello"}]}
|
||||
|
||||
data: {"choices":[{"text":" World"}]}
|
||||
|
||||
data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
// When timings data is present, it takes precedence
|
||||
assert.Equal(t, 10, metrics[0].InputTokens)
|
||||
assert.Equal(t, 20, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("error"))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("empty response body does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("not valid json"))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
expectedErr := assert.AnError
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
return expectedErr
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `{"result": "ok"}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
||||
t.Run("captures response body", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
copier := newBodyCopier(ginCtx.Writer)
|
||||
|
||||
testData := []byte("test response body")
|
||||
n, err := copier.Write(testData)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(testData), n)
|
||||
assert.Equal(t, testData, copier.body.Bytes())
|
||||
assert.Equal(t, string(testData), rec.Body.String())
|
||||
})
|
||||
|
||||
t.Run("sets start time on first write", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
copier := newBodyCopier(ginCtx.Writer)
|
||||
|
||||
assert.True(t, copier.StartTime().IsZero())
|
||||
|
||||
copier.Write([]byte("test"))
|
||||
|
||||
assert.False(t, copier.StartTime().IsZero())
|
||||
})
|
||||
|
||||
t.Run("preserves headers", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
copier := newBodyCopier(ginCtx.Writer)
|
||||
|
||||
copier.Header().Set("X-Test", "value")
|
||||
|
||||
assert.Equal(t, "value", rec.Header().Get("X-Test"))
|
||||
})
|
||||
|
||||
t.Run("preserves status code", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
copier := newBodyCopier(ginCtx.Writer)
|
||||
|
||||
copier.WriteHeader(http.StatusCreated)
|
||||
|
||||
// Gin's ResponseWriter tracks status internally
|
||||
assert.Equal(t, http.StatusCreated, copier.Status())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 1000)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
metricsPerGoroutine := 100
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < metricsPerGoroutine; j++ {
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "test-model",
|
||||
InputTokens: id*1000 + j,
|
||||
OutputTokens: j,
|
||||
})
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 100)
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
// Writer goroutine
|
||||
go func() {
|
||||
for i := 0; i < 50; i++ {
|
||||
mm.addMetrics(TokenMetrics{Model: "test-model"})
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Multiple reader goroutines
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 20; j++ {
|
||||
_ = mm.getMetrics()
|
||||
_, _ = mm.getMetricsJSON()
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
<-done
|
||||
wg.Wait()
|
||||
|
||||
// Final check
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 50, len(metrics))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
t.Run("prefers timings over usage data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
// Timings should take precedence over usage
|
||||
responseBody := `{
|
||||
"usage": {
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 25
|
||||
},
|
||||
"timings": {
|
||||
"prompt_n": 100,
|
||||
"predicted_n": 50,
|
||||
"prompt_per_second": 150.5,
|
||||
"predicted_per_second": 25.5,
|
||||
"prompt_ms": 500.0,
|
||||
"predicted_ms": 1500.0
|
||||
}
|
||||
}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
// Should use timings values, not usage values
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `{
|
||||
"timings": {
|
||||
"prompt_n": 100,
|
||||
"predicted_n": 50,
|
||||
"prompt_per_second": 150.5,
|
||||
"predicted_per_second": 25.5,
|
||||
"prompt_ms": 500.0,
|
||||
"predicted_ms": 1500.0
|
||||
}
|
||||
}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
|
||||
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
// Metrics should be found in the last data line before [DONE]
|
||||
responseBody := `data: {"choices":[{"text":"First"}]}
|
||||
|
||||
data: {"choices":[{"text":"Second"}]}
|
||||
|
||||
data: {"usage":{"prompt_tokens":100,"completion_tokens":50}}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `data: not json
|
||||
|
||||
data: [DONE]
|
||||
|
||||
`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("handles empty streaming response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := ``
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
// Empty body should not trigger WrapHandler processing
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||
mm := newMetricsMonitor(testLogger, 1000)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
CachedTokens: 100,
|
||||
InputTokens: 500,
|
||||
OutputTokens: 250,
|
||||
PromptPerSecond: 1200.5,
|
||||
TokensPerSecond: 45.8,
|
||||
DurationMs: 5000,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mm.addMetrics(metric)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||
mm := newMetricsMonitor(testLogger, 100)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
CachedTokens: 100,
|
||||
InputTokens: 500,
|
||||
OutputTokens: 250,
|
||||
PromptPerSecond: 1200.5,
|
||||
TokensPerSecond: 45.8,
|
||||
DurationMs: 5000,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mm.addMetrics(metric)
|
||||
}
|
||||
}
|
||||
+385
-65
@@ -2,19 +2,23 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type ProcessState string
|
||||
@@ -37,11 +41,13 @@ const (
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
ID string
|
||||
config config.ModelConfig
|
||||
cmd *exec.Cmd
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
|
||||
// PR #155 called to cancel the upstream process
|
||||
cmdMutex sync.RWMutex
|
||||
cancelUpstream context.CancelFunc
|
||||
|
||||
// closed when command exits
|
||||
@@ -53,12 +59,14 @@ type Process struct {
|
||||
healthCheckTimeout int
|
||||
healthCheckLoopInterval time.Duration
|
||||
|
||||
lastRequestHandled time.Time
|
||||
lastRequestHandledMutex sync.RWMutex
|
||||
lastRequestHandled time.Time
|
||||
|
||||
stateMutex sync.RWMutex
|
||||
state ProcessState
|
||||
|
||||
inFlightRequests sync.WaitGroup
|
||||
inFlightRequests sync.WaitGroup
|
||||
inFlightRequestsCount atomic.Int32
|
||||
|
||||
// used to block on multiple start() calls
|
||||
waitStarting sync.WaitGroup
|
||||
@@ -73,16 +81,35 @@ type Process struct {
|
||||
failedStartCount int
|
||||
}
|
||||
|
||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||
concurrentLimit := 10
|
||||
if config.ConcurrencyLimit > 0 {
|
||||
concurrentLimit = config.ConcurrencyLimit
|
||||
}
|
||||
|
||||
// Setup the reverse proxy.
|
||||
proxyURL, err := url.Parse(config.Proxy)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
|
||||
}
|
||||
|
||||
var reverseProxy *httputil.ReverseProxy
|
||||
if proxyURL != nil {
|
||||
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
reverseProxy: reverseProxy,
|
||||
cancelUpstream: nil,
|
||||
processLogger: processLogger,
|
||||
proxyLogger: proxyLogger,
|
||||
@@ -105,6 +132,20 @@ func (p *Process) LogMonitor() *LogMonitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) setLastRequestHandled(t time.Time) {
|
||||
p.lastRequestHandledMutex.Lock()
|
||||
defer p.lastRequestHandledMutex.Unlock()
|
||||
p.lastRequestHandled = t
|
||||
}
|
||||
|
||||
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
|
||||
func (p *Process) getLastRequestHandled() time.Time {
|
||||
p.lastRequestHandledMutex.RLock()
|
||||
defer p.lastRequestHandledMutex.RUnlock()
|
||||
return p.lastRequestHandled
|
||||
}
|
||||
|
||||
// custom error types for swapping state
|
||||
var (
|
||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||
@@ -128,6 +169,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
||||
}
|
||||
|
||||
p.state = newState
|
||||
|
||||
// Atomically increment waitStarting when entering StateStarting
|
||||
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
|
||||
if newState == StateStarting {
|
||||
p.waitStarting.Add(1)
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||
return p.state, nil
|
||||
@@ -156,6 +204,15 @@ func (p *Process) CurrentState() ProcessState {
|
||||
return p.state
|
||||
}
|
||||
|
||||
// forceState forces the process state to the new state with mutex protection.
|
||||
// This should only be used in exceptional cases where the normal state transition
|
||||
// validation via swapState() cannot be used.
|
||||
func (p *Process) forceState(newState ProcessState) {
|
||||
p.stateMutex.Lock()
|
||||
defer p.stateMutex.Unlock()
|
||||
p.state = newState
|
||||
}
|
||||
|
||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||
// it is a private method because starting is automatic but stopping can be called
|
||||
// at any time.
|
||||
@@ -189,7 +246,7 @@ func (p *Process) start() error {
|
||||
}
|
||||
}
|
||||
|
||||
p.waitStarting.Add(1)
|
||||
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
|
||||
defer p.waitStarting.Done()
|
||||
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||
|
||||
@@ -199,8 +256,12 @@ func (p *Process) start() error {
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
setProcAttributes(p.cmd)
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
p.cmdWaitChan = make(chan struct{})
|
||||
p.cmdMutex.Unlock()
|
||||
|
||||
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||
|
||||
@@ -210,7 +271,7 @@ func (p *Process) start() error {
|
||||
// Set process state to failed
|
||||
if err != nil {
|
||||
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||
p.state = StateStopped // force it into a stopped state
|
||||
p.forceState(StateStopped) // force it into a stopped state
|
||||
return fmt.Errorf(
|
||||
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||
strings.Join(args, " "), err, curState, swapErr,
|
||||
@@ -283,10 +344,12 @@ func (p *Process) start() error {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for all inflight requests to complete and ticker
|
||||
p.inFlightRequests.Wait()
|
||||
// skip the TTL check if there are inflight requests
|
||||
if p.inFlightRequestsCount.Load() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||
if time.Since(p.getLastRequestHandled()) > maxDuration {
|
||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||
p.Stop()
|
||||
return
|
||||
@@ -342,7 +405,7 @@ func (p *Process) Shutdown() {
|
||||
|
||||
p.stopCommand()
|
||||
// just force it to this state since there is no recovery from shutdown
|
||||
p.state = StateShutdown
|
||||
p.forceState(StateShutdown)
|
||||
}
|
||||
|
||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||
@@ -353,18 +416,33 @@ func (p *Process) stopCommand() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
}()
|
||||
|
||||
if p.cancelUpstream == nil {
|
||||
p.cmdMutex.RLock()
|
||||
cancelUpstream := p.cancelUpstream
|
||||
cmdWaitChan := p.cmdWaitChan
|
||||
p.cmdMutex.RUnlock()
|
||||
|
||||
if cancelUpstream == nil {
|
||||
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
p.cancelUpstream()
|
||||
<-p.cmdWaitChan
|
||||
cancelUpstream()
|
||||
<-cmdWaitChan
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
// wait a short time for a tcp connection to be established
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}).DialContext,
|
||||
},
|
||||
|
||||
// give a long time to respond to the health check endpoint
|
||||
// after the connection is established. See issue: 276
|
||||
Timeout: 5000 * time.Millisecond,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
@@ -387,6 +465,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if p.reverseProxy == nil {
|
||||
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
requestBeginTime := time.Now()
|
||||
var startDuration time.Duration
|
||||
|
||||
@@ -406,68 +490,75 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
p.inFlightRequests.Add(1)
|
||||
p.inFlightRequestsCount.Add(1)
|
||||
defer func() {
|
||||
p.lastRequestHandled = time.Now()
|
||||
p.setLastRequestHandled(time.Now())
|
||||
p.inFlightRequestsCount.Add(-1)
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// for #366
|
||||
// - extract streaming param from request context, should have been set by proxymanager
|
||||
var srw *statusResponseWriter
|
||||
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
// start a goroutine to stream loading status messages into the response writer
|
||||
// add a sync so the streaming client only runs when the goroutine has exited
|
||||
|
||||
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
|
||||
|
||||
// PR #417 (no support for anthropic v1/messages yet)
|
||||
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
|
||||
srw = newStatusResponseWriter(p, w)
|
||||
go srw.statusUpdates(swapCtx)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
|
||||
}
|
||||
|
||||
beginStartTime := time.Now()
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
cancelLoadCtx()
|
||||
if srw != nil {
|
||||
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
|
||||
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
|
||||
// before closing the connection. Without this, the connection would close before
|
||||
// the goroutine can write its cleanup messages, causing incomplete SSE output.
|
||||
srw.waitForCompletion(100 * time.Millisecond)
|
||||
} else {
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
}
|
||||
return
|
||||
}
|
||||
startDuration = time.Since(beginStartTime)
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
// should trigger srw to stop sending loading events ...
|
||||
cancelLoadCtx()
|
||||
|
||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
||||
if err == nil {
|
||||
req.ContentLength = contentLength
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// faster than io.Copy when streaming
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
// recover from http.ErrAbortHandler panics that can occur when the client
|
||||
// disconnects before the response is sent
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if r == http.ErrAbortHandler {
|
||||
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
|
||||
} else {
|
||||
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}()
|
||||
|
||||
if srw != nil {
|
||||
// Wait for the goroutine to finish writing its final messages
|
||||
const completionTimeout = 1 * time.Second
|
||||
if !srw.waitForCompletion(completionTimeout) {
|
||||
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
||||
}
|
||||
p.reverseProxy.ServeHTTP(srw, r)
|
||||
} else {
|
||||
p.reverseProxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
@@ -503,13 +594,16 @@ func (p *Process) waitForCmd() {
|
||||
case StateStopping:
|
||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||
p.state = StateStopped
|
||||
p.forceState(StateStopped)
|
||||
}
|
||||
default:
|
||||
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||
p.state = StateStopped // force it to be in this state
|
||||
p.forceState(StateStopped) // force it to be in this state
|
||||
}
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
close(p.cmdWaitChan)
|
||||
p.cmdMutex.Unlock()
|
||||
}
|
||||
|
||||
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||
@@ -524,7 +618,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
// replace ${PID} with the pid of the process
|
||||
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
if err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||
return err
|
||||
@@ -535,6 +629,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Env = p.cmd.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
@@ -550,3 +645,228 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting...",
|
||||
"Reticulating splines...",
|
||||
"Waking up the hamsters...",
|
||||
"Teaching the model manners...",
|
||||
"Convincing the GPU to participate...",
|
||||
"Loading weights (they're heavy)...",
|
||||
"Herding electrons...",
|
||||
"Compiling excuses for the delay...",
|
||||
"Downloading more RAM...",
|
||||
"Asking the model nicely to boot up...",
|
||||
"Bribing CUDA with cookies...",
|
||||
"Still loading (blame VRAM)...",
|
||||
"The model is fashionably late...",
|
||||
"Warming up those tensors...",
|
||||
"Making the neural net do push-ups...",
|
||||
"Your patience is appreciated (really)...",
|
||||
"Almost there (probably)...",
|
||||
"Loading like it's 1999...",
|
||||
"The model forgot where it put its keys...",
|
||||
"Quantum tunneling through layers...",
|
||||
"Negotiating with the PCIe bus...",
|
||||
"Defrosting frozen parameters...",
|
||||
"Teaching attention heads to focus...",
|
||||
"Running the matrix (slowly)...",
|
||||
"Untangling transformer blocks...",
|
||||
"Calibrating the flux capacitor...",
|
||||
"Spinning up the probability wheels...",
|
||||
"Waiting for the GPU to wake from its nap...",
|
||||
"Converting caffeine to compute...",
|
||||
"Allocating virtual patience...",
|
||||
"Performing arcane CUDA rituals...",
|
||||
"The model is stuck in traffic...",
|
||||
"Inflating embeddings...",
|
||||
"Summoning computational demons...",
|
||||
"Pleading with the OOM killer...",
|
||||
"Calculating the meaning of life (still at 42)...",
|
||||
"Training the training wheels...",
|
||||
"Optimizing the optimizer...",
|
||||
"Bootstrapping the bootstrapper...",
|
||||
"Loading loading screen...",
|
||||
"Processing processing logs...",
|
||||
"Buffering buffer overflow jokes...",
|
||||
"The model hit snooze...",
|
||||
"Debugging the debugger...",
|
||||
"Compiling the compiler...",
|
||||
"Parsing the parser (meta)...",
|
||||
"Tokenizing tokens...",
|
||||
"Encoding the encoder...",
|
||||
"Hashing hash browns...",
|
||||
"Forking spoons (not forks)...",
|
||||
"The model is contemplating existence...",
|
||||
"Transcending dimensional barriers...",
|
||||
"Invoking elder tensor gods...",
|
||||
"Unfurling probability clouds...",
|
||||
"Synchronizing parallel universes...",
|
||||
"The GPU is having second thoughts...",
|
||||
"Recalibrating reality matrices...",
|
||||
"Time is an illusion, loading doubly so...",
|
||||
"Convincing bits to flip themselves...",
|
||||
"The model is reading its own documentation...",
|
||||
}
|
||||
|
||||
type statusResponseWriter struct {
|
||||
hasWritten bool
|
||||
writer http.ResponseWriter
|
||||
process *Process
|
||||
wg sync.WaitGroup // Track goroutine completion
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
|
||||
s := &statusResponseWriter{
|
||||
writer: w,
|
||||
process: p,
|
||||
start: time.Now(),
|
||||
}
|
||||
|
||||
s.Header().Set("Content-Type", "text/event-stream") // SSE
|
||||
s.Header().Set("Cache-Control", "no-cache") // no-cache
|
||||
s.Header().Set("Connection", "keep-alive") // keep-alive
|
||||
s.WriteHeader(http.StatusOK) // send status code 200
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
|
||||
return s
|
||||
}
|
||||
|
||||
// statusUpdates sends status updates to the client while the model is loading
|
||||
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
// Recover from panics caused by client disconnection
|
||||
// Note: recover() only works within the same goroutine, so we need it here
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
duration := time.Since(s.start)
|
||||
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(" ")
|
||||
}()
|
||||
|
||||
// Create a shuffled copy of loadingRemarks
|
||||
remarks := make([]string, len(loadingRemarks))
|
||||
copy(remarks, loadingRemarks)
|
||||
rand.Shuffle(len(remarks), func(i, j int) {
|
||||
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||
})
|
||||
ri := 0
|
||||
|
||||
// Pick a random duration to send a remark
|
||||
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||
lastRemarkTime := time.Now()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if s.process.CurrentState() == StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's time for a snarky remark
|
||||
if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||
remark := remarks[ri%len(remarks)]
|
||||
ri++
|
||||
s.sendLine(fmt.Sprintf("\n%s", remark))
|
||||
lastRemarkTime = time.Now()
|
||||
// Pick a new random duration for the next remark
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else {
|
||||
s.sendData(".")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCompletion waits for the statusUpdates goroutine to finish
|
||||
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendLine(line string) {
|
||||
s.sendData(line + "\n")
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendData(data string) {
|
||||
// Create the proper SSE JSON structure
|
||||
type Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
}
|
||||
type Choice struct {
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
type SSEMessage struct {
|
||||
Choices []Choice `json:"choices"`
|
||||
}
|
||||
|
||||
msg := SSEMessage{
|
||||
Choices: []Choice{
|
||||
{
|
||||
Delta: Delta{
|
||||
ReasoningContent: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write SSE formatted data, panic if not able to write
|
||||
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
|
||||
}
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Header() http.Header {
|
||||
return s.writer.Header()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Write(data []byte) (int, error) {
|
||||
return s.writer.Write(data)
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
if s.hasWritten {
|
||||
return
|
||||
}
|
||||
s.hasWritten = true
|
||||
s.writer.WriteHeader(statusCode)
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
// Add Flush method
|
||||
func (s *statusResponseWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
+86
-12
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -90,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
config := ModelConfig{
|
||||
config := config.ModelConfig{
|
||||
Cmd: "nonexistent-command",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
@@ -325,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
|
||||
// should run and exit but interrupt the long checkHealthTimeout
|
||||
checkHealthTimeout := 5
|
||||
config := ModelConfig{
|
||||
config := config.ModelConfig{
|
||||
Cmd: "sleep 1",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
@@ -402,7 +403,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
port := getTestPort()
|
||||
|
||||
config := ModelConfig{
|
||||
conf := config.ModelConfig{
|
||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||
// to force the process to exit
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
@@ -410,7 +411,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// reduce to make testing go faster
|
||||
@@ -435,7 +436,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||
} else {
|
||||
assert.Contains(t, w.Body.String(), "unexpected EOF")
|
||||
// Upstream may be killed mid-response.
|
||||
// Assert an incomplete or partial response.
|
||||
assert.NotEqual(t, "12345", w.Body.String())
|
||||
}
|
||||
|
||||
close(waitChan)
|
||||
@@ -450,15 +453,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
config := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
config.CmdStop = "kill -TERM ${PID}"
|
||||
conf.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
@@ -470,15 +473,15 @@ func TestProcess_StopCmd(t *testing.T) {
|
||||
|
||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
expectedMessage := "test_env_not_emptied"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// ensure that the the default config does not blank out the inherited environment
|
||||
configWEnv := config
|
||||
configWEnv := conf
|
||||
|
||||
// ensure the additiona variables are appended to the process' environment
|
||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||
|
||||
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
|
||||
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||
|
||||
process1.start()
|
||||
@@ -491,3 +494,74 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||
|
||||
}
|
||||
|
||||
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
|
||||
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
|
||||
// handled appropriately.
|
||||
//
|
||||
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
|
||||
// can't copy the body. This can be caused by a client disconnecting before the full
|
||||
// response is sent from some reason.
|
||||
//
|
||||
// bug: https://github.com/mostlygeek/llama-swap/issues/362
|
||||
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
|
||||
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
|
||||
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
|
||||
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
expectedMessage := "panic_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// Start the process
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// Create a custom ResponseWriter that simulates a client disconnect
|
||||
// by panicking when Write is called after headers are sent
|
||||
panicWriter := &panicOnWriteResponseWriter{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
shouldPanic: true,
|
||||
}
|
||||
|
||||
// Make a request that will trigger the panic
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
|
||||
|
||||
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
|
||||
// ProxyRequest should catch and handle this panic gracefully.
|
||||
process.ProxyRequest(panicWriter, req)
|
||||
|
||||
// If we get here, the panic was properly recovered in ProxyRequest
|
||||
// The process should still be in a ready state
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
|
||||
// to simulate a client disconnect after headers are sent
|
||||
// used by: TestProcess_ReverseProxyPanicIsHandled
|
||||
type panicOnWriteResponseWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
shouldPanic bool
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
|
||||
w.headerWritten = true
|
||||
w.ResponseRecorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.shouldPanic && w.headerWritten {
|
||||
// Simulate the panic that httputil.ReverseProxy throws
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
return w.ResponseRecorder.Write(b)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
+37
-2
@@ -5,12 +5,14 @@ import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config Config
|
||||
config config.Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
@@ -24,7 +26,7 @@ type ProcessGroup struct {
|
||||
lastUsedProcess string
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
@@ -60,10 +62,20 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
|
||||
if pg.swap {
|
||||
pg.Lock()
|
||||
if pg.lastUsedProcess != modelID {
|
||||
|
||||
// is there something already running?
|
||||
if pg.lastUsedProcess != "" {
|
||||
pg.processes[pg.lastUsedProcess].Stop()
|
||||
}
|
||||
|
||||
// wait for the request to the new model to be fully handled
|
||||
// and prevent race conditions see issue #277
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
pg.lastUsedProcess = modelID
|
||||
|
||||
// short circuit and exit
|
||||
pg.Unlock()
|
||||
return nil
|
||||
}
|
||||
pg.Unlock()
|
||||
}
|
||||
@@ -76,6 +88,29 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
process, exists := pg.processes[modelID]
|
||||
if !exists {
|
||||
pg.Unlock()
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
if pg.lastUsedProcess == modelID {
|
||||
pg.lastUsedProcess = ""
|
||||
}
|
||||
pg.Unlock()
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
+39
-20
@@ -4,21 +4,23 @@ import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
@@ -33,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
@@ -44,32 +46,49 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
||||
assert.False(t, pg.HasMember("model3"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
// use the same listening so if a model is already running, it will fail
|
||||
// this is a way to test that swap isolation is working
|
||||
// properly when there are parallel requests made at the
|
||||
// same time.
|
||||
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model1", "model2"}
|
||||
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(len(tests))
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
|
||||
// make sure only one process is in the running state
|
||||
count := 0
|
||||
for _, process := range pg.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
}(modelName)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||
|
||||
+256
-41
@@ -8,12 +8,16 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -21,10 +25,12 @@ const (
|
||||
PROFILE_SPLIT_CHAR = ":"
|
||||
)
|
||||
|
||||
type proxyCtxKey string
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config Config
|
||||
config config.Config
|
||||
ginEngine *gin.Engine
|
||||
|
||||
// logging
|
||||
@@ -32,16 +38,21 @@ type ProxyManager struct {
|
||||
upstreamLogger *LogMonitor
|
||||
muxLogger *LogMonitor
|
||||
|
||||
metricsMonitor *MetricsMonitor
|
||||
metricsMonitor *metricsMonitor
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
|
||||
// shutdown signaling
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
|
||||
// version info
|
||||
buildDate string
|
||||
commit string
|
||||
version string
|
||||
}
|
||||
|
||||
func New(config Config) *ProxyManager {
|
||||
func New(config config.Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
@@ -69,8 +80,39 @@ func New(config Config) *ProxyManager {
|
||||
upstreamLogger.SetLogLevel(LevelInfo)
|
||||
}
|
||||
|
||||
// see: https://go.dev/src/time/format.go
|
||||
timeFormats := map[string]string{
|
||||
"ansic": time.ANSIC,
|
||||
"unixdate": time.UnixDate,
|
||||
"rubydate": time.RubyDate,
|
||||
"rfc822": time.RFC822,
|
||||
"rfc822z": time.RFC822Z,
|
||||
"rfc850": time.RFC850,
|
||||
"rfc1123": time.RFC1123,
|
||||
"rfc1123z": time.RFC1123Z,
|
||||
"rfc3339": time.RFC3339,
|
||||
"rfc3339nano": time.RFC3339Nano,
|
||||
"kitchen": time.Kitchen,
|
||||
"stamp": time.Stamp,
|
||||
"stampmilli": time.StampMilli,
|
||||
"stampmicro": time.StampMicro,
|
||||
"stampnano": time.StampNano,
|
||||
}
|
||||
|
||||
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(config.LogTimeFormat))]; ok {
|
||||
proxyLogger.SetLogTimeFormat(timeFormat)
|
||||
upstreamLogger.SetLogTimeFormat(timeFormat)
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
|
||||
var maxMetrics int
|
||||
if config.MetricsMaxInMemory <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
} else {
|
||||
maxMetrics = config.MetricsMaxInMemory
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
ginEngine: gin.New(),
|
||||
@@ -79,12 +121,16 @@ func New(config Config) *ProxyManager {
|
||||
muxLogger: stdoutLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
metricsMonitor: NewMetricsMonitor(&config),
|
||||
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
|
||||
buildDate: "unknown",
|
||||
commit: "abcd1234",
|
||||
version: "0",
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
@@ -94,11 +140,48 @@ func New(config Config) *ProxyManager {
|
||||
}
|
||||
|
||||
pm.setupGinEngine()
|
||||
|
||||
// run any startup hooks
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
// do it in the background, don't block startup -- not sure if good idea yet
|
||||
go func() {
|
||||
discardWriter := &DiscardWriter{}
|
||||
for _, realModelName := range config.Hooks.OnStartup.Preload {
|
||||
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
|
||||
if err != nil {
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
Success: false,
|
||||
})
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||
continue
|
||||
} else {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
Success: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) setupGinEngine() {
|
||||
|
||||
pm.ginEngine.Use(func(c *gin.Context) {
|
||||
|
||||
// don't log the Wake on Lan proxy health check
|
||||
if c.Request.URL.Path == "/wol-health" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
@@ -152,21 +235,30 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
c.Next()
|
||||
})
|
||||
|
||||
mm := MetricsMiddleware(pm)
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyInferenceHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyInferenceHandler)
|
||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||
pm.ginEngine.POST("/v1/messages", pm.proxyInferenceHandler)
|
||||
|
||||
// Support embeddings
|
||||
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
|
||||
// Support embeddings and reranking
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /reranking endpoint + aliases
|
||||
pm.ginEngine.POST("/reranking", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /infill endpoint for code infilling
|
||||
pm.ginEngine.POST("/infill", pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /completion endpoint
|
||||
pm.ginEngine.POST("/completion", pm.proxyInferenceHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
@@ -186,10 +278,17 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
// see cmd/wol-proxy/wol-proxy.go, not logged
|
||||
pm.ginEngine.GET("/wol-health", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
||||
@@ -312,23 +411,49 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
record := gin.H{
|
||||
"id": id,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
newRecord := func(modelId string) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
data = append(data, newRecord(id))
|
||||
|
||||
data = append(data, record)
|
||||
// Include aliases
|
||||
if pm.config.IncludeAliasesInList {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by the "id" key
|
||||
sort.Slice(data, func(i, j int) bool {
|
||||
si, _ := data[i]["id"].(string)
|
||||
sj, _ := data[j]["id"].(string)
|
||||
return si < sj
|
||||
})
|
||||
|
||||
// Set CORS headers if origin exists
|
||||
if origin := c.GetHeader("Origin"); origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
@@ -342,34 +467,102 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
requestedModel := c.Param("model_id")
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
if requestedModel == "" {
|
||||
// split the upstream path by / and search for the model name
|
||||
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
|
||||
if len(parts) == 0 {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
||||
modelFound := false
|
||||
searchModelName := ""
|
||||
var modelName, remainingPath string
|
||||
for i, part := range parts {
|
||||
if parts[i] == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if searchModelName == "" {
|
||||
searchModelName = part
|
||||
} else {
|
||||
searchModelName = searchModelName + "/" + parts[i]
|
||||
}
|
||||
|
||||
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
modelName = real
|
||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||
modelFound = true
|
||||
|
||||
// Check if this is exactly a model name with no additional path
|
||||
// and doesn't end with a trailing slash
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
// Build new URL with query parameters preserved
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Use 308 for non-GET/HEAD requests to preserve method
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !modelFound {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
||||
originalPath := c.Request.URL.Path
|
||||
c.Request.URL.Path = remainingPath
|
||||
|
||||
// attempt to record metrics if it is a POST request
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
return
|
||||
}
|
||||
|
||||
realModelName := c.GetString("ls-real-model-name") // Should be set in MetricsMiddleware
|
||||
if realModelName == "" {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "ls-real-model-name not set")
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -411,10 +604,24 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
c.Request.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -574,3 +781,11 @@ func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
pm.buildDate = buildDate
|
||||
pm.commit = commit
|
||||
pm.version = version
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
@@ -23,8 +25,10 @@ func addApiHandlers(pm *ProxyManager) {
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
apiGroup.GET("/version", pm.apiGetVersion)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,6 +104,8 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering SSE
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
sendBuffer := make(chan messageEnvelope, 25)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
@@ -132,7 +138,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
sendMetrics := func(metrics TokenMetrics) {
|
||||
sendMetrics := func(metrics []TokenMetrics) {
|
||||
jsonData, err := json.Marshal(metrics)
|
||||
if err == nil {
|
||||
select {
|
||||
@@ -168,16 +174,14 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
* Send Metrics data
|
||||
*/
|
||||
defer event.On(func(e TokenMetricsEvent) {
|
||||
sendMetrics(e.Metrics)
|
||||
sendMetrics([]TokenMetrics{e.Metrics})
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
for _, metrics := range pm.metricsMonitor.GetMetrics() {
|
||||
sendMetrics(metrics)
|
||||
}
|
||||
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -195,10 +199,40 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
|
||||
jsonData, err := pm.metricsMonitor.getMetricsJSON()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, map[string]string{
|
||||
"version": pm.version,
|
||||
"commit": pm.commit,
|
||||
"build_date": pm.buildDate,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,6 +28,8 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering streamed logs
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
|
||||
+531
-101
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -9,18 +10,47 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TestResponseRecorder adds CloseNotify to httptest.ResponseRecorder.
|
||||
// "If you want to write your own tests around streams you will need a Recorder that can handle CloseNotifier."
|
||||
// The tests can panic otherwise:
|
||||
// panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify
|
||||
// See: https://github.com/gin-gonic/gin/issues/1815
|
||||
// TestResponseRecorder is taken from gin's own tests: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
|
||||
type TestResponseRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
closeChannel chan bool
|
||||
}
|
||||
|
||||
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
||||
return r.closeChannel
|
||||
}
|
||||
|
||||
func (r *TestResponseRecorder) closeClient() {
|
||||
r.closeChannel <- true
|
||||
}
|
||||
|
||||
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||
return &TestResponseRecorder{
|
||||
httptest.NewRecorder(),
|
||||
make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
@@ -33,23 +63,22 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
@@ -71,7 +100,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
t.Run(requestedModel, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -87,14 +116,14 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||
// Test that a persistent group is not affected by the swapping behaviour of
|
||||
// other groups.
|
||||
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
// the forever group is persistent and should not be affected by model1
|
||||
"forever": {
|
||||
Swap: true,
|
||||
@@ -113,7 +142,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||
for _, requestedModel := range tests {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -131,9 +160,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
@@ -156,7 +185,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
@@ -194,9 +223,9 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||
model2Config.Description = " "
|
||||
|
||||
config := Config{
|
||||
config := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
@@ -209,7 +238,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Add("Origin", "i-am-the-origin")
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
// Call the listModelsHandler
|
||||
proxy.ServeHTTP(w, req)
|
||||
@@ -279,6 +308,199 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler_WithMetadata(t *testing.T) {
|
||||
// Process config through LoadConfigFromReader to apply macro substitution
|
||||
configYaml := `
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
startPort: 10000
|
||||
models:
|
||||
model1:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
macros:
|
||||
PORT_NUM: 10001
|
||||
TEMP: 0.7
|
||||
NAME: "llama"
|
||||
metadata:
|
||||
port: ${PORT_NUM}
|
||||
temperature: ${TEMP}
|
||||
enabled: true
|
||||
note: "Running on port ${PORT_NUM}"
|
||||
nested:
|
||||
value: ${TEMP}
|
||||
model2:
|
||||
cmd: /path/to/server -p ${PORT}
|
||||
`
|
||||
processedConfig, err := config.LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(processedConfig)
|
||||
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response struct {
|
||||
Data []map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, response.Data, 2)
|
||||
|
||||
// Find model1 and model2 in response
|
||||
var model1Data, model2Data map[string]any
|
||||
for _, model := range response.Data {
|
||||
if model["id"] == "model1" {
|
||||
model1Data = model
|
||||
} else if model["id"] == "model2" {
|
||||
model2Data = model
|
||||
}
|
||||
}
|
||||
|
||||
// Verify model1 has llamaswap_meta
|
||||
assert.NotNil(t, model1Data)
|
||||
meta, exists := model1Data["meta"]
|
||||
if !assert.True(t, exists, "model1 should have meta key") {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
metaMap := meta.(map[string]any)
|
||||
|
||||
lsmeta, exists := metaMap["llamaswap"]
|
||||
if !assert.True(t, exists, "model1 should have meta.llamaswap key") {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
lsmetamap := lsmeta.(map[string]any)
|
||||
|
||||
// Verify type preservation
|
||||
assert.Equal(t, float64(10001), lsmetamap["port"]) // JSON numbers are float64
|
||||
assert.Equal(t, 0.7, lsmetamap["temperature"])
|
||||
assert.Equal(t, true, lsmetamap["enabled"])
|
||||
// Verify string interpolation
|
||||
assert.Equal(t, "Running on port 10001", lsmetamap["note"])
|
||||
// Verify nested structure
|
||||
nested := lsmetamap["nested"].(map[string]any)
|
||||
assert.Equal(t, 0.7, nested["value"])
|
||||
|
||||
// Verify model2 does NOT have llamaswap_meta
|
||||
assert.NotNil(t, model2Data)
|
||||
_, exists = model2Data["llamaswap_meta"]
|
||||
assert.False(t, exists, "model2 should not have llamaswap_meta")
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
||||
// Intentionally add models in non-sorted order and with an unlisted model
|
||||
config := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"zeta": getTestSimpleResponderConfig("zeta"),
|
||||
"alpha": getTestSimpleResponderConfig("alpha"),
|
||||
"beta": getTestSimpleResponderConfig("beta"),
|
||||
"hidden": func() config.ModelConfig {
|
||||
mc := getTestSimpleResponderConfig("hidden")
|
||||
mc.Unlisted = true
|
||||
return mc
|
||||
}(),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Request models list
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// We expect only the listed models in sorted order by id
|
||||
expectedOrder := []string{"alpha", "beta", "zeta"}
|
||||
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
|
||||
got := make([]string, 0, len(response.Data))
|
||||
for _, m := range response.Data {
|
||||
id, _ := m["id"].(string)
|
||||
got = append(got, id)
|
||||
}
|
||||
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
|
||||
// Configure alias
|
||||
config := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
IncludeAliasesInList: true,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": func() config.ModelConfig {
|
||||
mc := getTestSimpleResponderConfig("model1")
|
||||
mc.Name = "Model 1"
|
||||
mc.Aliases = []string{"alias1"}
|
||||
return mc
|
||||
}(),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Request models list
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// We expect both base id and alias
|
||||
var model1Data, alias1Data map[string]any
|
||||
for _, model := range response.Data {
|
||||
if model["id"] == "model1" {
|
||||
model1Data = model
|
||||
} else if model["id"] == "alias1" {
|
||||
alias1Data = model
|
||||
}
|
||||
}
|
||||
|
||||
// Verify model1 has name
|
||||
assert.NotNil(t, model1Data)
|
||||
_, exists := model1Data["name"]
|
||||
if !assert.True(t, exists, "model1 should have name key") {
|
||||
t.FailNow()
|
||||
}
|
||||
name1, ok := model1Data["name"].(string)
|
||||
assert.True(t, ok, "name1 should be a string")
|
||||
|
||||
// Verify alias1 has name
|
||||
assert.NotNil(t, alias1Data)
|
||||
_, exists = alias1Data["name"]
|
||||
if !assert.True(t, exists, "alias1 should have name key") {
|
||||
t.FailNow()
|
||||
}
|
||||
name2, ok := alias1Data["name"].(string)
|
||||
assert.True(t, ok, "name2 should be a string")
|
||||
|
||||
// Name keys should match
|
||||
assert.Equal(t, name1, name2)
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
@@ -290,15 +512,15 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||
model3Config.Proxy = "http://localhost:10003/"
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": model3Config,
|
||||
},
|
||||
LogLevel: "error",
|
||||
Groups: map[string]GroupConfig{
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"test": {
|
||||
Swap: false,
|
||||
Members: []string{"model1", "model2", "model3"},
|
||||
@@ -316,7 +538,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
defer wg.Done()
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||
proxy.ServeHTTP(w, req)
|
||||
@@ -333,38 +555,92 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Unload(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
proxy := New(conf)
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||
req = httptest.NewRequest("GET", "/unload", nil)
|
||||
w = httptest.NewRecorder()
|
||||
w = CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
|
||||
// give it a bit of time to stop
|
||||
<-time.After(time.Millisecond * 250)
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||
select {
|
||||
case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
|
||||
// good
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for model1 to stop")
|
||||
}
|
||||
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
||||
const testGroupId = "testGroup"
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
testGroupId: {
|
||||
Swap: false,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
// start both model
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState())
|
||||
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
if !assert.Equal(t, w.Body.String(), "OK") {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
|
||||
// good
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for model1 to stop")
|
||||
}
|
||||
|
||||
assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
|
||||
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Shared configuration
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
@@ -385,7 +661,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
t.Run("no models loaded", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/running", nil)
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -403,13 +679,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Load just a model.
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Simulate browser call for the `/running` endpoint.
|
||||
req = httptest.NewRequest("GET", "/running", nil)
|
||||
w = httptest.NewRecorder()
|
||||
w = CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
var response RunningResponse
|
||||
@@ -427,9 +703,9 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -461,7 +737,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Verify the response
|
||||
@@ -480,15 +756,15 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||
modelConfig.UseModelName = upstreamModelName
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
requestedModel := "model1"
|
||||
@@ -496,7 +772,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -530,7 +806,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
// Create the request with the multipart form data
|
||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Verify the response
|
||||
@@ -543,9 +819,9 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -598,7 +874,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
@@ -611,27 +887,40 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Upstream(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
models:
|
||||
model1:
|
||||
cmd: %s -port ${PORT} -silent -respond model1
|
||||
aliases: [model-alias]
|
||||
`, getSimpleResponderPath())
|
||||
|
||||
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
t.Run("main model name", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
})
|
||||
|
||||
t.Run("model alias", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "model1", rec.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -642,7 +931,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
|
||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -654,14 +943,14 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
|
||||
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig("model1")
|
||||
modelConfig.Filters = ModelFilters{
|
||||
modelConfig.Filters = config.ModelFilters{
|
||||
StripParams: "temperature, model, stream",
|
||||
}
|
||||
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
LogLevel: "error",
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": modelConfig,
|
||||
},
|
||||
})
|
||||
@@ -670,7 +959,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
@@ -686,10 +975,29 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
// t.Logf("%v", response)
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rec := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "OK", rec.Body.String())
|
||||
}
|
||||
|
||||
// Ensure the custom llama-server /completion endpoint proxies correctly
|
||||
func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -698,33 +1006,78 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Make a non-streaming request
|
||||
reqBody := `{"model":"model1", "stream": false}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "model1")
|
||||
}
|
||||
|
||||
// Check that metrics were recorded
|
||||
metrics := proxy.metricsMonitor.GetMetrics()
|
||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
|
||||
func TestProxyManager_StartupHooks(t *testing.T) {
|
||||
|
||||
// using real YAML as the configuration has gotten more complex
|
||||
// is the right approach as LoadConfigFromReader() does a lot more
|
||||
// than parse YAML now. Eventually migrate all tests to use this approach
|
||||
configStr := strings.Replace(`
|
||||
logLevel: error
|
||||
hooks:
|
||||
on_startup:
|
||||
preload:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
preloadTestGroup:
|
||||
swap: false
|
||||
members:
|
||||
- model1
|
||||
- model2
|
||||
models:
|
||||
model1:
|
||||
cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model1
|
||||
model2:
|
||||
cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model2
|
||||
`, "${simpleresponderpath}", simpleResponderPath, -1)
|
||||
|
||||
// Create a test model configuration
|
||||
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
if !assert.NoError(t, err, "Invalid configuration") {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the last metric has the correct model
|
||||
lastMetric := metrics[len(metrics)-1]
|
||||
assert.Equal(t, "model1", lastMetric.Model)
|
||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||
preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events
|
||||
|
||||
unsub := event.On(func(e ModelPreloadedEvent) {
|
||||
preloadChan <- e
|
||||
})
|
||||
|
||||
defer unsub()
|
||||
|
||||
// Create the proxy which should trigger preloading
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-preloadChan:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for models to preload")
|
||||
}
|
||||
}
|
||||
// make sure they are both loaded
|
||||
_, foundGroup := proxy.processGroups["preloadTestGroup"]
|
||||
if !assert.True(t, foundGroup, "preloadTestGroup should exist") {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model1"].CurrentState())
|
||||
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model2"].CurrentState())
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
@@ -733,25 +1086,102 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
endpoints := []string{
|
||||
"/api/events",
|
||||
"/logs/stream",
|
||||
"/logs/stream/proxy",
|
||||
"/logs/stream/upstream",
|
||||
}
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
t.Run(endpoint, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", endpoint, nil)
|
||||
req = req.WithContext(ctx)
|
||||
rec := CreateTestResponseRecorder()
|
||||
|
||||
// Run handler in goroutine and wait for context timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
proxy.ServeHTTP(rec, req)
|
||||
}()
|
||||
|
||||
// Wait for either the handler to complete or context to timeout
|
||||
<-ctx.Done()
|
||||
|
||||
// At this point, the handler has either finished or been cancelled
|
||||
// Wait for the goroutine to fully exit before reading
|
||||
<-done
|
||||
|
||||
// Now it's safe to read from rec - no more concurrent writes
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"streaming-model": getTestSimpleResponderConfig("streaming-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Make a streaming request
|
||||
reqBody := `{"model":"model1", "stream": true}`
|
||||
reqBody := `{"model":"streaming-model"}`
|
||||
// simple-responder will return text/event-stream when stream=true is in the query
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
rec := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
proxy.ServeHTTP(rec, req)
|
||||
|
||||
// Check that metrics were recorded
|
||||
metrics := proxy.metricsMonitor.GetMetrics()
|
||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
|
||||
return
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
||||
assert.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
||||
}
|
||||
|
||||
func TestProxyManager_ApiGetVersion(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
// Version test map
|
||||
versionTest := map[string]string{
|
||||
"build_date": "1970-01-01T00:00:00Z",
|
||||
"commit": "cc915ddb6f04a42d9cd1f524e1d46ec6ed069fdc",
|
||||
"version": "v001",
|
||||
}
|
||||
|
||||
// Verify the last metric has the correct model
|
||||
lastMetric := metrics[len(metrics)-1]
|
||||
assert.Equal(t, "model1", lastMetric.Model)
|
||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||
proxy := New(config)
|
||||
proxy.SetVersion(versionTest["build_date"], versionTest["commit"], versionTest["version"])
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/version", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Ensure json response
|
||||
assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type"))
|
||||
|
||||
// Check for attributes
|
||||
response := map[string]string{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
for key, value := range versionTest {
|
||||
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
|
||||
}
|
||||
}
|
||||
|
||||
Generated
+193
-92
File diff suppressed because it is too large
Load Diff
+7
-6
@@ -4,21 +4,21 @@
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"start": "vite",
|
||||
"build": "tsc -b && vite build --emptyOutDir",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@tailwindcss/vite": "^4.1.8",
|
||||
"@tanstack/react-query": "^5.80.6",
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0",
|
||||
"react-router-dom": "^7.6.2",
|
||||
"tailwindcss": "^4.1.8"
|
||||
"react-icons": "^5.5.0",
|
||||
"react-resizable-panels": "^3.0.4",
|
||||
"react-router-dom": "^7.6.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.25.0",
|
||||
"@tailwindcss/vite": "^4.1.8",
|
||||
"@types/react": "^19.1.2",
|
||||
"@types/react-dom": "^19.1.2",
|
||||
"@vitejs/plugin-react": "^4.4.1",
|
||||
@@ -26,8 +26,9 @@
|
||||
"eslint-plugin-react-hooks": "^5.2.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.19",
|
||||
"globals": "^16.0.0",
|
||||
"tailwindcss": "^4.1.8",
|
||||
"typescript": "~5.8.3",
|
||||
"typescript-eslint": "^8.30.1",
|
||||
"vite": "^6.3.5"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+25
-37
@@ -1,48 +1,36 @@
|
||||
import { BrowserRouter as Router, Routes, Route, Navigate, NavLink } from "react-router-dom";
|
||||
import { useEffect } from "react";
|
||||
import { Navigate, Route, BrowserRouter as Router, Routes } from "react-router-dom";
|
||||
import { Header } from "./components/Header";
|
||||
import { useAPI } from "./contexts/APIProvider";
|
||||
import { useTheme } from "./contexts/ThemeProvider";
|
||||
import { APIProvider } from "./contexts/APIProvider";
|
||||
import ActivityPage from "./pages/Activity";
|
||||
import LogViewerPage from "./pages/LogViewer";
|
||||
import ModelPage from "./pages/Models";
|
||||
import ActivityPage from "./pages/Activity";
|
||||
|
||||
function App() {
|
||||
const theme = useTheme();
|
||||
const { setConnectionState } = useTheme();
|
||||
|
||||
const { connectionStatus } = useAPI();
|
||||
|
||||
// Synchronize the window.title connections state with the actual connection state
|
||||
useEffect(() => {
|
||||
setConnectionState(connectionStatus);
|
||||
}, [connectionStatus]);
|
||||
|
||||
return (
|
||||
<Router basename="/ui/">
|
||||
<APIProvider>
|
||||
<div>
|
||||
<nav className="bg-surface border-b border-border p-2 h-[75px]">
|
||||
<div className="flex items-center justify-between mx-auto px-4 h-full">
|
||||
<h1 className="flex items-center p-0">llama-swap</h1>
|
||||
<div className="flex items-center space-x-4">
|
||||
<NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Logs
|
||||
</NavLink>
|
||||
<div className="flex flex-col h-screen">
|
||||
<Header />
|
||||
|
||||
<NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Models
|
||||
</NavLink>
|
||||
|
||||
<NavLink to="/activity" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||
Activity
|
||||
</NavLink>
|
||||
<button className="btn btn--sm" onClick={theme.toggleTheme}>
|
||||
{theme.isDarkMode ? "🌙" : "☀️"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<main className="mx-auto py-4 px-4">
|
||||
<Routes>
|
||||
<Route path="/" element={<LogViewerPage />} />
|
||||
<Route path="/models" element={<ModelPage />} />
|
||||
<Route path="/activity" element={<ActivityPage />} />
|
||||
<Route path="*" element={<Navigate to="/" replace />} />
|
||||
</Routes>
|
||||
</main>
|
||||
</div>
|
||||
</APIProvider>
|
||||
<main className="flex-1 overflow-auto p-4">
|
||||
<Routes>
|
||||
<Route path="/" element={<LogViewerPage />} />
|
||||
<Route path="/models" element={<ModelPage />} />
|
||||
<Route path="/activity" element={<ActivityPage />} />
|
||||
<Route path="*" element={<Navigate to="/" replace />} />
|
||||
</Routes>
|
||||
</main>
|
||||
</div>
|
||||
</Router>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { useMemo } from "react";
|
||||
|
||||
const ConnectionStatusIcon = () => {
|
||||
const { connectionStatus, versionInfo } = useAPI();
|
||||
|
||||
const eventStatusColor = useMemo(() => {
|
||||
switch (connectionStatus) {
|
||||
case "connected":
|
||||
return "bg-emerald-500";
|
||||
case "connecting":
|
||||
return "bg-amber-500";
|
||||
case "disconnected":
|
||||
default:
|
||||
return "bg-red-500";
|
||||
}
|
||||
}, [connectionStatus]);
|
||||
|
||||
return (
|
||||
<div className="flex items-center" title={`Event Stream: ${connectionStatus ?? 'unknown'}\nAPI Version: ${versionInfo?.version ?? 'unknown'}\nCommit Hash: ${versionInfo?.commit?.substring(0,7) ?? 'unknown'}\nBuild Date: ${versionInfo?.build_date ?? 'unknown'}`}>
|
||||
<span className={`inline-block w-3 h-3 rounded-full ${eventStatusColor} mr-2`}></span>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ConnectionStatusIcon;
|
||||
@@ -0,0 +1,56 @@
|
||||
import { useCallback } from "react";
|
||||
import { RiMoonFill, RiSunFill } from "react-icons/ri";
|
||||
import { NavLink, type NavLinkRenderProps } from "react-router-dom";
|
||||
import { useTheme } from "../contexts/ThemeProvider";
|
||||
import ConnectionStatusIcon from "./ConnectionStatus";
|
||||
|
||||
export function Header() {
|
||||
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle, isNarrow } = useTheme();
|
||||
const handleTitleChange = useCallback(
|
||||
(newTitle: string) => {
|
||||
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
|
||||
},
|
||||
[setAppTitle]
|
||||
);
|
||||
|
||||
const navLinkClass = ({ isActive }: NavLinkRenderProps) =>
|
||||
`text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`;
|
||||
|
||||
return (
|
||||
<header className={`flex items-center justify-between bg-surface border-b border-border px-4 ${isNarrow ? "py-1 h-[60px]" : "p-2 h-[75px]"}`}>
|
||||
{screenWidth !== "xs" && screenWidth !== "sm" && (
|
||||
<h1
|
||||
contentEditable
|
||||
suppressContentEditableWarning
|
||||
className="p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
|
||||
onBlur={(e) => handleTitleChange(e.currentTarget.textContent || "(set title)")}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
handleTitleChange(e.currentTarget.textContent || "(set title)");
|
||||
e.currentTarget.blur();
|
||||
}
|
||||
}}
|
||||
>
|
||||
{appTitle}
|
||||
</h1>
|
||||
)}
|
||||
|
||||
<menu className="flex items-center gap-4">
|
||||
<NavLink to="/" className={navLinkClass} type="button">
|
||||
Logs
|
||||
</NavLink>
|
||||
<NavLink to="/models" className={navLinkClass} type="button">
|
||||
Models
|
||||
</NavLink>
|
||||
<NavLink to="/activity" className={navLinkClass} type="button">
|
||||
Activity
|
||||
</NavLink>
|
||||
<button className="" onClick={toggleTheme}>
|
||||
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
|
||||
</button>
|
||||
<ConnectionStatusIcon />
|
||||
</menu>
|
||||
</header>
|
||||
);
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import { useRef, createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
|
||||
import { createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
|
||||
import type { ConnectionState } from "../lib/types";
|
||||
|
||||
type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
|
||||
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||
@@ -15,19 +16,24 @@ interface APIProviderType {
|
||||
models: Model[];
|
||||
listModels: () => Promise<Model[]>;
|
||||
unloadAllModels: () => Promise<void>;
|
||||
unloadSingleModel: (model: string) => Promise<void>;
|
||||
loadModel: (model: string) => Promise<void>;
|
||||
enableAPIEvents: (enabled: boolean) => void;
|
||||
proxyLogs: string;
|
||||
upstreamLogs: string;
|
||||
metrics: Metrics[];
|
||||
connectionStatus: ConnectionState;
|
||||
versionInfo: VersionInfo;
|
||||
}
|
||||
|
||||
interface Metrics {
|
||||
id: number;
|
||||
timestamp: string;
|
||||
model: string;
|
||||
cache_tokens: number;
|
||||
input_tokens: number;
|
||||
output_tokens: number;
|
||||
prompt_per_second: number;
|
||||
tokens_per_second: number;
|
||||
duration_ms: number;
|
||||
}
|
||||
@@ -36,22 +42,37 @@ interface LogData {
|
||||
source: "upstream" | "proxy";
|
||||
data: string;
|
||||
}
|
||||
|
||||
interface APIEventEnvelope {
|
||||
type: "modelStatus" | "logData" | "metrics";
|
||||
data: string;
|
||||
}
|
||||
|
||||
interface VersionInfo {
|
||||
build_date: string;
|
||||
commit: string;
|
||||
version: string;
|
||||
}
|
||||
|
||||
const APIContext = createContext<APIProviderType | undefined>(undefined);
|
||||
type APIProviderProps = {
|
||||
children: ReactNode;
|
||||
autoStartAPIEvents?: boolean;
|
||||
};
|
||||
|
||||
let apiEventSource: EventSource | null = null;
|
||||
|
||||
export function APIProvider({ children, autoStartAPIEvents = true }: APIProviderProps) {
|
||||
const [proxyLogs, setProxyLogs] = useState("");
|
||||
const [upstreamLogs, setUpstreamLogs] = useState("");
|
||||
const [metrics, setMetrics] = useState<Metrics[]>([]);
|
||||
const apiEventSource = useRef<EventSource | null>(null);
|
||||
const [connectionStatus, setConnectionState] = useState<ConnectionState>("disconnected");
|
||||
const [versionInfo, setVersionInfo] = useState<VersionInfo>({
|
||||
build_date: "unknown",
|
||||
commit: "unknown",
|
||||
version: "unknown"
|
||||
});
|
||||
//const apiEventSource = useRef<EventSource | null>(null);
|
||||
|
||||
const [models, setModels] = useState<Model[]>([]);
|
||||
|
||||
@@ -64,8 +85,8 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
|
||||
const enableAPIEvents = useCallback((enabled: boolean) => {
|
||||
if (!enabled) {
|
||||
apiEventSource.current?.close();
|
||||
apiEventSource.current = null;
|
||||
apiEventSource?.close();
|
||||
apiEventSource = null;
|
||||
setMetrics([]);
|
||||
return;
|
||||
}
|
||||
@@ -74,15 +95,34 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
const initialDelay = 1000; // 1 second
|
||||
|
||||
const connect = () => {
|
||||
const eventSource = new EventSource("/api/events");
|
||||
apiEventSource?.close();
|
||||
apiEventSource = new EventSource("/api/events");
|
||||
|
||||
eventSource.onmessage = (e: MessageEvent) => {
|
||||
setConnectionState("connecting");
|
||||
|
||||
apiEventSource.onopen = () => {
|
||||
// clear everything out on connect to keep things in sync
|
||||
setProxyLogs("");
|
||||
setUpstreamLogs("");
|
||||
setMetrics([]); // clear metrics on reconnect
|
||||
setModels([]); // clear models on reconnect
|
||||
retryCount = 0;
|
||||
setConnectionState("connected");
|
||||
};
|
||||
|
||||
apiEventSource.onmessage = (e: MessageEvent) => {
|
||||
try {
|
||||
const message = JSON.parse(e.data) as APIEventEnvelope;
|
||||
switch (message.type) {
|
||||
case "modelStatus":
|
||||
{
|
||||
const models = JSON.parse(message.data) as Model[];
|
||||
|
||||
// sort models by name and id
|
||||
models.sort((a, b) => {
|
||||
return (a.name + a.id).localeCompare(b.name + b.id);
|
||||
});
|
||||
|
||||
setModels(models);
|
||||
}
|
||||
break;
|
||||
@@ -101,9 +141,9 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
|
||||
case "metrics":
|
||||
{
|
||||
const newMetric = JSON.parse(message.data) as Metrics;
|
||||
const newMetrics = JSON.parse(message.data) as Metrics[];
|
||||
setMetrics((prevMetrics) => {
|
||||
return [newMetric, ...prevMetrics];
|
||||
return [...newMetrics, ...prevMetrics];
|
||||
});
|
||||
}
|
||||
break;
|
||||
@@ -112,19 +152,39 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
console.error(e.data, err);
|
||||
}
|
||||
};
|
||||
eventSource.onerror = () => {
|
||||
eventSource.close();
|
||||
|
||||
apiEventSource.onerror = () => {
|
||||
apiEventSource?.close();
|
||||
retryCount++;
|
||||
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
|
||||
setConnectionState("disconnected");
|
||||
setTimeout(connect, delay);
|
||||
};
|
||||
|
||||
apiEventSource.current = eventSource;
|
||||
};
|
||||
|
||||
connect();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// fetch version
|
||||
const fetchVersion = async () => {
|
||||
try {
|
||||
const response = await fetch("/api/version");
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const data: VersionInfo = await response.json();
|
||||
setVersionInfo(data);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
};
|
||||
|
||||
if (connectionStatus === 'connected') {
|
||||
fetchVersion();
|
||||
}
|
||||
}, [connectionStatus]);
|
||||
|
||||
useEffect(() => {
|
||||
if (autoStartAPIEvents) {
|
||||
enableAPIEvents(true);
|
||||
@@ -151,7 +211,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
|
||||
const unloadAllModels = useCallback(async () => {
|
||||
try {
|
||||
const response = await fetch(`/api/models/unload/`, {
|
||||
const response = await fetch(`/api/models/unload`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!response.ok) {
|
||||
@@ -163,6 +223,20 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
}
|
||||
}, []);
|
||||
|
||||
const unloadSingleModel = useCallback(async (model: string) => {
|
||||
try {
|
||||
const response = await fetch(`/api/models/unload/${model}`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to unload model: ${response.status}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to unload model", model, error);
|
||||
throw error;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const loadModel = useCallback(async (model: string) => {
|
||||
try {
|
||||
const response = await fetch(`/upstream/${model}/`, {
|
||||
@@ -182,13 +256,16 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
models,
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
unloadSingleModel,
|
||||
loadModel,
|
||||
enableAPIEvents,
|
||||
proxyLogs,
|
||||
upstreamLogs,
|
||||
metrics,
|
||||
connectionStatus,
|
||||
versionInfo,
|
||||
}),
|
||||
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
|
||||
[models, listModels, unloadAllModels, unloadSingleModel, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics, connectionStatus, versionInfo]
|
||||
);
|
||||
|
||||
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
import { createContext, useContext, useEffect, type ReactNode } from "react";
|
||||
import { createContext, useContext, useEffect, type ReactNode, useMemo, useState } from "react";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
import type { ConnectionState } from "../lib/types";
|
||||
|
||||
type ScreenWidth = "xs" | "sm" | "md" | "lg" | "xl" | "2xl";
|
||||
type ThemeContextType = {
|
||||
isDarkMode: boolean;
|
||||
screenWidth: ScreenWidth;
|
||||
isNarrow: boolean;
|
||||
toggleTheme: () => void;
|
||||
|
||||
// for managing the window title and connection state information
|
||||
appTitle: string;
|
||||
setAppTitle: (title: string) => void;
|
||||
setConnectionState: (state: ConnectionState) => void;
|
||||
};
|
||||
|
||||
const ThemeContext = createContext<ThemeContextType | undefined>(undefined);
|
||||
@@ -13,15 +22,70 @@ type ThemeProviderProps = {
|
||||
};
|
||||
|
||||
export function ThemeProvider({ children }: ThemeProviderProps) {
|
||||
const [appTitle, setAppTitle] = usePersistentState("app-title", "llama-swap");
|
||||
const [connectionState, setConnectionState] = useState<ConnectionState>("disconnected");
|
||||
|
||||
/**
|
||||
* Set the document.title with informative information
|
||||
*/
|
||||
useEffect(() => {
|
||||
const connectionIcon = connectionState === "connecting" ? "🟡" : connectionState === "connected" ? "🟢" : "🔴";
|
||||
document.title = connectionIcon + " " + appTitle; // Set initial title
|
||||
}, [appTitle, connectionState]);
|
||||
|
||||
const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false);
|
||||
const [screenWidth, setScreenWidth] = useState<ScreenWidth>("md"); // Default to md
|
||||
|
||||
// matches tailwind classes
|
||||
// https://tailwindcss.com/docs/responsive-design
|
||||
useEffect(() => {
|
||||
const checkInnerWidth = () => {
|
||||
const innerWidth = window.innerWidth;
|
||||
if (innerWidth < 640) {
|
||||
setScreenWidth("xs");
|
||||
} else if (innerWidth < 768) {
|
||||
setScreenWidth("sm");
|
||||
} else if (innerWidth < 1024) {
|
||||
setScreenWidth("md");
|
||||
} else if (innerWidth < 1280) {
|
||||
setScreenWidth("lg");
|
||||
} else if (innerWidth < 1536) {
|
||||
setScreenWidth("xl");
|
||||
} else {
|
||||
setScreenWidth("2xl");
|
||||
}
|
||||
};
|
||||
|
||||
checkInnerWidth();
|
||||
window.addEventListener("resize", checkInnerWidth);
|
||||
|
||||
return () => window.removeEventListener("resize", checkInnerWidth);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light");
|
||||
}, [isDarkMode]);
|
||||
|
||||
const toggleTheme = () => setIsDarkMode((prev) => !prev);
|
||||
const isNarrow = useMemo(() => {
|
||||
return screenWidth === "xs" || screenWidth === "sm" || screenWidth === "md";
|
||||
}, [screenWidth]);
|
||||
|
||||
return <ThemeContext.Provider value={{ isDarkMode, toggleTheme }}>{children}</ThemeContext.Provider>;
|
||||
return (
|
||||
<ThemeContext.Provider
|
||||
value={{
|
||||
isDarkMode,
|
||||
toggleTheme,
|
||||
screenWidth,
|
||||
isNarrow,
|
||||
appTitle,
|
||||
setAppTitle,
|
||||
setConnectionState,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</ThemeContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useTheme(): ThemeContextType {
|
||||
|
||||
+10
-2
@@ -93,6 +93,14 @@
|
||||
@apply px-4;
|
||||
}
|
||||
|
||||
/* Tables */
|
||||
table th {
|
||||
@apply p-2 font-semibold;
|
||||
}
|
||||
table td {
|
||||
@apply p-2;
|
||||
}
|
||||
|
||||
/* Navigation Header */
|
||||
|
||||
.navlink {
|
||||
@@ -122,7 +130,7 @@
|
||||
|
||||
/* Status Badges */
|
||||
.status {
|
||||
@apply inline-block px-2 py-1 text-xs font-medium rounded-full;
|
||||
@apply inline-block px-2 py-1 text-xs font-medium rounded-lg;
|
||||
}
|
||||
|
||||
.status--ready {
|
||||
@@ -140,7 +148,7 @@
|
||||
|
||||
/* Buttons */
|
||||
.btn {
|
||||
@apply bg-surface p-2 px-4 text-sm rounded-full border border-2 transition-colors duration-200 border-btn-border;
|
||||
@apply bg-surface py-2 px-4 text-sm rounded-md border transition-colors duration-200 border-btn-border;
|
||||
}
|
||||
|
||||
.btn:hover {
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
export type ConnectionState = "connected" | "connecting" | "disconnected";
|
||||
+4
-1
@@ -3,11 +3,14 @@ import { createRoot } from "react-dom/client";
|
||||
import "./index.css";
|
||||
import App from "./App.tsx";
|
||||
import { ThemeProvider } from "./contexts/ThemeProvider";
|
||||
import { APIProvider } from "./contexts/APIProvider";
|
||||
|
||||
createRoot(document.getElementById("root")!).render(
|
||||
<StrictMode>
|
||||
<ThemeProvider>
|
||||
<App />
|
||||
<APIProvider>
|
||||
<App />
|
||||
</APIProvider>
|
||||
</ThemeProvider>
|
||||
</StrictMode>
|
||||
);
|
||||
|
||||
+88
-45
@@ -1,10 +1,6 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { useMemo } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
|
||||
const formatTimestamp = (timestamp: string): string => {
|
||||
return new Date(timestamp).toLocaleString();
|
||||
};
|
||||
|
||||
const formatSpeed = (speed: number): string => {
|
||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||
};
|
||||
@@ -13,57 +9,80 @@ const formatDuration = (ms: number): string => {
|
||||
return (ms / 1000).toFixed(2) + "s";
|
||||
};
|
||||
|
||||
const ActivityPage = () => {
|
||||
const { metrics } = useAPI();
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const formatRelativeTime = (timestamp: string): string => {
|
||||
const now = new Date();
|
||||
const date = new Date(timestamp);
|
||||
const diffInSeconds = Math.floor((now.getTime() - date.getTime()) / 1000);
|
||||
|
||||
useEffect(() => {
|
||||
if (metrics.length > 0) {
|
||||
setError(null);
|
||||
}
|
||||
}, [metrics]);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="p-6">
|
||||
<h1 className="text-2xl font-bold mb-4">Activity</h1>
|
||||
<div className="bg-red-50 border border-red-200 rounded-md p-4">
|
||||
<p className="text-red-800">{error}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
// Handle future dates by returning "just now"
|
||||
if (diffInSeconds < 5) {
|
||||
return "now";
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-6">
|
||||
<h1 className="text-2xl font-bold mb-4">Activity</h1>
|
||||
if (diffInSeconds < 60) {
|
||||
return `${diffInSeconds}s ago`;
|
||||
}
|
||||
|
||||
{metrics.length === 0 ? (
|
||||
const diffInMinutes = Math.floor(diffInSeconds / 60);
|
||||
if (diffInMinutes < 60) {
|
||||
return `${diffInMinutes}m ago`;
|
||||
}
|
||||
|
||||
const diffInHours = Math.floor(diffInMinutes / 60);
|
||||
if (diffInHours < 24) {
|
||||
return `${diffInHours}h ago`;
|
||||
}
|
||||
|
||||
return "a while ago";
|
||||
};
|
||||
|
||||
const ActivityPage = () => {
|
||||
const { metrics } = useAPI();
|
||||
const sortedMetrics = useMemo(() => {
|
||||
return [...metrics].sort((a, b) => b.id - a.id);
|
||||
}, [metrics]);
|
||||
|
||||
return (
|
||||
<div className="p-2">
|
||||
<h1 className="text-2xl font-bold">Activity</h1>
|
||||
|
||||
{metrics.length === 0 && (
|
||||
<div className="text-center py-8">
|
||||
<p className="text-gray-600">No metrics data available</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
)}
|
||||
{metrics.length > 0 && (
|
||||
<div className="card overflow-auto">
|
||||
<table className="min-w-full divide-y">
|
||||
<thead>
|
||||
<tr>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Timestamp</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Model</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Input Tokens</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Output Tokens</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Generation Speed</th>
|
||||
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Duration</th>
|
||||
<thead className="border-gray-200 dark:border-white/10">
|
||||
<tr className="text-left text-xs uppercase tracking-wider">
|
||||
<th className="px-6 py-3">ID</th>
|
||||
<th className="px-6 py-3">Time</th>
|
||||
<th className="px-6 py-3">Model</th>
|
||||
<th className="px-6 py-3">
|
||||
Cached <Tooltip content="prompt tokens from cache" />
|
||||
</th>
|
||||
<th className="px-6 py-3">
|
||||
Prompt <Tooltip content="new prompt tokens processed" />
|
||||
</th>
|
||||
<th className="px-6 py-3">Generated</th>
|
||||
<th className="px-6 py-3">Prompt Processing</th>
|
||||
<th className="px-6 py-3">Generation Speed</th>
|
||||
<th className="px-6 py-3">Duration</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="divide-y">
|
||||
{metrics.map((metric, index) => (
|
||||
<tr key={`${metric.id}-${index}`}>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatTimestamp(metric.timestamp)}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.model}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.input_tokens.toLocaleString()}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.output_tokens.toLocaleString()}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.tokens_per_second)}</td>
|
||||
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatDuration(metric.duration_ms)}</td>
|
||||
{sortedMetrics.map((metric) => (
|
||||
<tr key={metric.id} className="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
||||
<td className="px-4 py-4">{metric.id + 1 /* un-zero index */}</td>
|
||||
<td className="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
|
||||
<td className="px-6 py-4">{metric.model}</td>
|
||||
<td className="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td>
|
||||
<td className="px-6 py-4">{metric.input_tokens.toLocaleString()}</td>
|
||||
<td className="px-6 py-4">{metric.output_tokens.toLocaleString()}</td>
|
||||
<td className="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td>
|
||||
<td className="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td>
|
||||
<td className="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
@@ -74,4 +93,28 @@ const ActivityPage = () => {
|
||||
);
|
||||
};
|
||||
|
||||
interface TooltipProps {
|
||||
content: string;
|
||||
}
|
||||
|
||||
const Tooltip: React.FC<TooltipProps> = ({ content }) => {
|
||||
return (
|
||||
<div className="relative group inline-block">
|
||||
ⓘ
|
||||
<div
|
||||
className="absolute top-full left-1/2 transform -translate-x-1/2 mt-2
|
||||
px-3 py-2 bg-gray-900 text-white text-sm rounded-md
|
||||
opacity-0 group-hover:opacity-100 transition-opacity
|
||||
duration-200 pointer-events-none whitespace-nowrap z-50 normal-case"
|
||||
>
|
||||
{content}
|
||||
<div
|
||||
className="absolute bottom-full left-1/2 transform -translate-x-1/2
|
||||
border-4 border-transparent border-b-gray-900"
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ActivityPage;
|
||||
|
||||
+71
-46
@@ -1,15 +1,38 @@
|
||||
import { useState, useEffect, useRef, useMemo, useCallback } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||
import {
|
||||
RiTextWrap,
|
||||
RiAlignJustify,
|
||||
RiFontSize,
|
||||
RiMenuSearchLine,
|
||||
RiMenuSearchFill,
|
||||
RiCloseCircleFill,
|
||||
} from "react-icons/ri";
|
||||
import { useTheme } from "../contexts/ThemeProvider";
|
||||
|
||||
const LogViewer = () => {
|
||||
const { proxyLogs, upstreamLogs } = useAPI();
|
||||
const { screenWidth } = useTheme();
|
||||
const direction = screenWidth === "xs" || screenWidth === "sm" ? "vertical" : "horizontal";
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-5" style={{ height: "calc(100vh - 125px)" }}>
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
<PanelGroup direction={direction} className="gap-2" autoSaveId="logviewer-panel-group">
|
||||
<Panel id="proxy" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||
</Panel>
|
||||
<PanelResizeHandle
|
||||
className={
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||
}
|
||||
/>
|
||||
<Panel id="upstream" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -17,17 +40,15 @@ interface LogPanelProps {
|
||||
id: string;
|
||||
title: string;
|
||||
logData: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
const [isCollapsed, setIsCollapsed] = usePersistentState(`logPanel-${id}-isCollapsed`, false);
|
||||
export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
|
||||
const [filterRegex, setFilterRegex] = useState("");
|
||||
const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">(
|
||||
`logPanel-${id}-fontSize`,
|
||||
"normal"
|
||||
);
|
||||
const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false);
|
||||
const [showFilter, setShowFilter] = usePersistentState(`logPanel-${id}-showFilter`, false);
|
||||
|
||||
const textWrapClass = useMemo(() => {
|
||||
return wrapText ? "whitespace-pre-wrap" : "whitespace-pre";
|
||||
@@ -48,6 +69,19 @@ export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
});
|
||||
}, []);
|
||||
|
||||
const toggleWrapText = useCallback(() => {
|
||||
setTextWrap((prev) => !prev);
|
||||
}, []);
|
||||
|
||||
const toggleFilter = useCallback(() => {
|
||||
if (showFilter) {
|
||||
setShowFilter(false);
|
||||
setFilterRegex(""); // Clear filter when closing
|
||||
} else {
|
||||
setShowFilter(true);
|
||||
}
|
||||
}, [filterRegex, setFilterRegex, showFilter]);
|
||||
|
||||
const fontSizeClass = useMemo(() => {
|
||||
switch (fontSize) {
|
||||
case "xxs":
|
||||
@@ -81,56 +115,47 @@ export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
|
||||
}, [filteredLogs]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`bg-surface border border-border rounded-lg overflow-hidden flex flex-col ${
|
||||
!isCollapsed && "h-full"
|
||||
} ${className || ""}`}
|
||||
>
|
||||
<div className="p-4 border-b border-border bg-secondary">
|
||||
<div className="flex flex-col md:flex-row md:items-center md:justify-between gap-4">
|
||||
{/* Title - Always full width on mobile, normal on desktop */}
|
||||
<div className="w-full md:w-auto" onClick={() => setIsCollapsed(!isCollapsed)}>
|
||||
<h3 className="m-0 text-lg">{title}</h3>
|
||||
<div className="rounded-lg overflow-hidden flex flex-col bg-gray-950/5 dark:bg-white/10 h-full p-1">
|
||||
<div className="p-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="m-0 text-lg p-0">{title}</h3>
|
||||
|
||||
<div className="flex gap-2 items-center">
|
||||
<button className="btn border-0" onClick={toggleFontSize}>
|
||||
<RiFontSize />
|
||||
</button>
|
||||
<button className="btn border-0" onClick={toggleWrapText}>
|
||||
{wrapText ? <RiTextWrap /> : <RiAlignJustify />}
|
||||
</button>
|
||||
<button className="btn border-0" onClick={toggleFilter}>
|
||||
{showFilter ? <RiMenuSearchFill /> : <RiMenuSearchLine />}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col sm:flex-row gap-4 w-full md:w-auto">
|
||||
{/* Sizing Buttons - Stacks vertically on mobile */}
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<button className="btn" onClick={toggleFontSize}>
|
||||
font: {fontSize}
|
||||
</button>
|
||||
<button className="btn" onClick={() => setTextWrap((prev) => !prev)}>
|
||||
{wrapText ? "wrap" : "wrap off"}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||
<div className="flex flex-1 min-w-0 gap-2">
|
||||
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||
{showFilter && (
|
||||
<div className="mt-2 w-full">
|
||||
<div className="flex gap-2 items-center w-full">
|
||||
<input
|
||||
type="text"
|
||||
className="flex-1 min-w-[120px] text-sm border p-2 rounded"
|
||||
className="w-full text-sm border border-gray-950/10 dark:border-white/5 p-2 rounded outline-none"
|
||||
placeholder="Filter logs..."
|
||||
value={filterRegex}
|
||||
onChange={(e) => setFilterRegex(e.target.value)}
|
||||
/>
|
||||
<button className="btn" onClick={() => setFilterRegex("")}>
|
||||
Clear
|
||||
<button className="pl-2" onClick={() => setFilterRegex("")}>
|
||||
<RiCloseCircleFill size="24" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="rounded-lg bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||
<pre ref={preTagRef} className={`${textWrapClass} ${fontSizeClass} h-full overflow-auto p-4`}>
|
||||
{filteredLogs}
|
||||
</pre>
|
||||
</div>
|
||||
|
||||
{!isCollapsed && (
|
||||
<div className="flex-1 bg-background font-mono text-sm p-3 overflow-hidden">
|
||||
<pre
|
||||
ref={preTagRef}
|
||||
className={`h-full p-4 overflow-y-auto whitespace-pre min-h-0 ${textWrapClass} ${fontSizeClass}`}
|
||||
>
|
||||
{filteredLogs}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
+457
-85
@@ -2,11 +2,47 @@ import { useState, useCallback, useMemo } from "react";
|
||||
import { useAPI } from "../contexts/APIProvider";
|
||||
import { LogPanel } from "./LogViewer";
|
||||
import { usePersistentState } from "../hooks/usePersistentState";
|
||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||
import { useTheme } from "../contexts/ThemeProvider";
|
||||
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine, RiMenuFill } from "react-icons/ri";
|
||||
|
||||
export default function ModelsPage() {
|
||||
const { models, unloadAllModels, loadModel, upstreamLogs, metrics } = useAPI();
|
||||
const { isNarrow } = useTheme();
|
||||
const direction = isNarrow ? "vertical" : "horizontal";
|
||||
const { upstreamLogs } = useAPI();
|
||||
|
||||
return (
|
||||
<PanelGroup direction={direction} className="gap-2" autoSaveId={"models-panel-group"}>
|
||||
<Panel id="models" defaultSize={50} minSize={isNarrow ? 0 : 25} maxSize={100} collapsible={isNarrow}>
|
||||
<ModelsPanel />
|
||||
</Panel>
|
||||
|
||||
<PanelResizeHandle
|
||||
className={
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||
}
|
||||
/>
|
||||
<Panel collapsible={true} defaultSize={50} minSize={0}>
|
||||
<div className="flex flex-col h-full space-y-4">
|
||||
{direction === "horizontal" && <StatsPanel />}
|
||||
<div className="flex-1 min-h-0">
|
||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
</div>
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
);
|
||||
}
|
||||
|
||||
function ModelsPanel() {
|
||||
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
|
||||
const { isNarrow } = useTheme();
|
||||
const [isUnloading, setIsUnloading] = useState(false);
|
||||
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
|
||||
const filteredModels = useMemo(() => {
|
||||
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||
@@ -19,103 +55,439 @@ export default function ModelsPage() {
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
// at least give it a second to show the unloading message
|
||||
setTimeout(() => {
|
||||
setIsUnloading(false);
|
||||
}, 1000);
|
||||
}
|
||||
}, []);
|
||||
}, [unloadAllModels]);
|
||||
|
||||
const [totalRequests, totalTokens, avgTokensPerSecond] = useMemo(() => {
|
||||
const totalRequests = metrics.length;
|
||||
if (totalRequests === 0) {
|
||||
return [0, 0, 0];
|
||||
}
|
||||
const totalTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
|
||||
return [totalRequests, totalTokens, avgTokensPerSecond];
|
||||
}, [metrics]);
|
||||
const toggleIdorName = useCallback(() => {
|
||||
setShowIdorName((prev) => (prev === "name" ? "id" : "name"));
|
||||
}, [showIdorName]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="flex flex-col md:flex-row gap-4">
|
||||
{/* Left Column */}
|
||||
<div className="w-full md:w-1/2 flex items-top">
|
||||
<div className="card w-full">
|
||||
<h2 className="">Models</h2>
|
||||
<div className="flex justify-between">
|
||||
<button className="btn" onClick={() => setShowUnlisted(!showUnlisted)} style={{ lineHeight: "1.2" }}>
|
||||
{showUnlisted ? "🟢 unlisted" : "⚫️ unlisted"}
|
||||
<div className="card h-full flex flex-col">
|
||||
<div className="shrink-0">
|
||||
<div className="flex justify-between items-baseline">
|
||||
<h2 className={isNarrow ? "text-xl" : ""}>Models</h2>
|
||||
{isNarrow && (
|
||||
<div className="relative">
|
||||
<button className="btn text-base flex items-center gap-2 py-1" onClick={() => setMenuOpen(!menuOpen)}>
|
||||
<RiMenuFill size="20" />
|
||||
</button>
|
||||
<button className="btn" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
{isUnloading ? "Stopping ..." : "Stop All"}
|
||||
{menuOpen && (
|
||||
<div className="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
|
||||
<button
|
||||
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onClick={() => {
|
||||
toggleIdorName();
|
||||
setMenuOpen(false);
|
||||
}}
|
||||
>
|
||||
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "Show Name" : "Show ID"}
|
||||
</button>
|
||||
<button
|
||||
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onClick={() => {
|
||||
setShowUnlisted(!showUnlisted);
|
||||
setMenuOpen(false);
|
||||
}}
|
||||
>
|
||||
{showUnlisted ? <RiEyeOffFill size="20" /> : <RiEyeFill size="20" />}{" "}
|
||||
{showUnlisted ? "Hide Unlisted" : "Show Unlisted"}
|
||||
</button>
|
||||
<button
|
||||
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onClick={() => {
|
||||
handleUnloadAllModels();
|
||||
setMenuOpen(false);
|
||||
}}
|
||||
disabled={isUnloading}
|
||||
>
|
||||
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{!isNarrow && (
|
||||
<div className="flex justify-between">
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
className="btn text-base flex items-center gap-2"
|
||||
onClick={toggleIdorName}
|
||||
style={{ lineHeight: "1.2" }}
|
||||
>
|
||||
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"}
|
||||
</button>
|
||||
|
||||
<button
|
||||
className="btn text-base flex items-center gap-2"
|
||||
onClick={() => setShowUnlisted(!showUnlisted)}
|
||||
style={{ lineHeight: "1.2" }}
|
||||
>
|
||||
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<table className="w-full mt-4">
|
||||
<thead>
|
||||
<tr className="border-b border-primary">
|
||||
<th className="text-left p-2">Name</th>
|
||||
<th className="text-left p-2"></th>
|
||||
<th className="text-left p-2">State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{filteredModels.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
||||
<td className="p-2">
|
||||
<a href={`/upstream/${model.id}/`} className="underline" target="_blank">
|
||||
{model.name !== "" ? model.name : model.id}
|
||||
</a>
|
||||
{model.description != "" && (
|
||||
<p>
|
||||
<em>{model.description}</em>
|
||||
</p>
|
||||
)}
|
||||
</td>
|
||||
<td className="p-2 w-[50px]">
|
||||
<button
|
||||
className="btn btn--sm"
|
||||
disabled={model.state !== "stopped"}
|
||||
onClick={() => loadModel(model.id)}
|
||||
>
|
||||
Load
|
||||
</button>
|
||||
</td>
|
||||
<td className="p-2 w-[75px]">
|
||||
<span className={`status status--${model.state}`}>{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
<button
|
||||
className="btn text-base flex items-center gap-2"
|
||||
onClick={handleUnloadAllModels}
|
||||
disabled={isUnloading}
|
||||
>
|
||||
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right Column */}
|
||||
<div className="w-full md:w-1/2 flex flex-col" style={{ height: "calc(100vh - 125px)" }}>
|
||||
<div className="card mb-4 min-h-[225px]">
|
||||
<h2>Chat Activity</h2>
|
||||
<table className="w-full border border-gray-200">
|
||||
<tbody>
|
||||
<tr className="border-b border-gray-200">
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Requests</td>
|
||||
<td className="py-2 px-4 text-right">{totalRequests}</td>
|
||||
</tr>
|
||||
<tr className="border-b border-gray-200">
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Total Tokens Generated</td>
|
||||
<td className="py-2 px-4 text-right">{totalTokens}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td className="py-2 px-4 font-medium border-r border-gray-200">Average Tokens/Second</td>
|
||||
<td className="py-2 px-4 text-right">{avgTokensPerSecond}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
<table className="w-full">
|
||||
<thead className="sticky top-0 bg-card z-10">
|
||||
<tr className="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||
<th>{showIdorName === "id" ? "Model ID" : "Name"}</th>
|
||||
<th></th>
|
||||
<th>State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{filteredModels.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td className={`${model.unlisted ? "text-txtsecondary" : ""}`}>
|
||||
<a href={`/upstream/${model.id}/`} className="font-semibold" target="_blank">
|
||||
{showIdorName === "id" ? model.id : model.name !== "" ? model.name : model.id}
|
||||
</a>
|
||||
|
||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||
</div>
|
||||
{!!model.description && (
|
||||
<p className={model.unlisted ? "text-opacity-70" : ""}>
|
||||
<em>{model.description}</em>
|
||||
</p>
|
||||
)}
|
||||
</td>
|
||||
<td className="w-12">
|
||||
{model.state === "stopped" ? (
|
||||
<button className="btn btn--sm" onClick={() => loadModel(model.id)}>
|
||||
Load
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="btn btn--sm"
|
||||
onClick={() => unloadSingleModel(model.id)}
|
||||
disabled={model.state !== "ready"}
|
||||
>
|
||||
Unload
|
||||
</button>
|
||||
)}
|
||||
</td>
|
||||
<td className="w-20">
|
||||
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface HistogramData {
|
||||
bins: number[];
|
||||
min: number;
|
||||
max: number;
|
||||
binSize: number;
|
||||
p99: number;
|
||||
p95: number;
|
||||
p50: number;
|
||||
}
|
||||
|
||||
function TokenHistogram({ data }: { data: HistogramData }) {
|
||||
const { bins, min, max, p50, p95, p99 } = data;
|
||||
const maxCount = Math.max(...bins);
|
||||
|
||||
const height = 120;
|
||||
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
|
||||
|
||||
// Use viewBox for responsive sizing
|
||||
const viewBoxWidth = 600;
|
||||
const chartWidth = viewBoxWidth - padding.left - padding.right;
|
||||
const chartHeight = height - padding.top - padding.bottom;
|
||||
|
||||
const barWidth = chartWidth / bins.length;
|
||||
const range = max - min;
|
||||
|
||||
// Calculate x position for a given value
|
||||
const getXPosition = (value: number) => {
|
||||
return padding.left + ((value - min) / range) * chartWidth;
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mt-2 w-full">
|
||||
<svg
|
||||
viewBox={`0 0 ${viewBoxWidth} ${height}`}
|
||||
className="w-full h-auto"
|
||||
preserveAspectRatio="xMidYMid meet"
|
||||
>
|
||||
{/* Y-axis */}
|
||||
<line
|
||||
x1={padding.left}
|
||||
y1={padding.top}
|
||||
x2={padding.left}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
strokeWidth="1"
|
||||
opacity="0.3"
|
||||
/>
|
||||
|
||||
{/* X-axis */}
|
||||
<line
|
||||
x1={padding.left}
|
||||
y1={height - padding.bottom}
|
||||
x2={viewBoxWidth - padding.right}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
strokeWidth="1"
|
||||
opacity="0.3"
|
||||
/>
|
||||
|
||||
{/* Histogram bars */}
|
||||
{bins.map((count, i) => {
|
||||
const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0;
|
||||
const x = padding.left + i * barWidth;
|
||||
const y = height - padding.bottom - barHeight;
|
||||
const binStart = min + i * data.binSize;
|
||||
const binEnd = binStart + data.binSize;
|
||||
|
||||
return (
|
||||
<g key={i}>
|
||||
<rect
|
||||
x={x}
|
||||
y={y}
|
||||
width={Math.max(barWidth - 1, 1)}
|
||||
height={barHeight}
|
||||
fill="currentColor"
|
||||
opacity="0.6"
|
||||
className="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
|
||||
/>
|
||||
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
|
||||
</g>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Percentile lines */}
|
||||
<line
|
||||
x1={getXPosition(p50)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(p50)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeDasharray="4 2"
|
||||
opacity="0.7"
|
||||
className="text-gray-600 dark:text-gray-400"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1={getXPosition(p95)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(p95)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeDasharray="4 2"
|
||||
opacity="0.7"
|
||||
className="text-orange-500 dark:text-orange-400"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1={getXPosition(p99)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(p99)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeDasharray="4 2"
|
||||
opacity="0.7"
|
||||
className="text-green-500 dark:text-green-400"
|
||||
/>
|
||||
|
||||
{/* X-axis labels */}
|
||||
<text
|
||||
x={padding.left}
|
||||
y={height - 5}
|
||||
fontSize="10"
|
||||
fill="currentColor"
|
||||
opacity="0.6"
|
||||
textAnchor="start"
|
||||
>
|
||||
{min.toFixed(1)}
|
||||
</text>
|
||||
|
||||
<text
|
||||
x={viewBoxWidth - padding.right}
|
||||
y={height - 5}
|
||||
fontSize="10"
|
||||
fill="currentColor"
|
||||
opacity="0.6"
|
||||
textAnchor="end"
|
||||
>
|
||||
{max.toFixed(1)}
|
||||
</text>
|
||||
|
||||
{/* X-axis label */}
|
||||
<text
|
||||
x={padding.left + chartWidth / 2}
|
||||
y={height - 2}
|
||||
fontSize="10"
|
||||
fill="currentColor"
|
||||
opacity="0.6"
|
||||
textAnchor="middle"
|
||||
>
|
||||
Tokens/Second Distribution
|
||||
</text>
|
||||
</svg>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function StatsPanel() {
|
||||
const { metrics } = useAPI();
|
||||
|
||||
const [totalRequests, totalInputTokens, totalOutputTokens, tokenStats, histogramData] = useMemo(() => {
|
||||
const totalRequests = metrics.length;
|
||||
if (totalRequests === 0) {
|
||||
return [0, 0, 0, { p99: 0, p95: 0, p50: 0 }, null];
|
||||
}
|
||||
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||
|
||||
// Calculate token statistics using output_tokens and duration_ms
|
||||
// Filter out metrics with invalid duration or output tokens
|
||||
const validMetrics = metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
|
||||
if (validMetrics.length === 0) {
|
||||
return [totalRequests, totalInputTokens, totalOutputTokens, { p99: 0, p95: 0, p50: 0 }, null];
|
||||
}
|
||||
|
||||
// Calculate tokens/second for each valid metric
|
||||
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
|
||||
|
||||
// Sort for percentile calculation
|
||||
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
|
||||
|
||||
// Calculate percentiles - showing speed thresholds where X% of requests are SLOWER (below)
|
||||
// P99: 99% of requests are slower than this speed (99th percentile - fast requests)
|
||||
// P95: 95% of requests are slower than this speed (95th percentile)
|
||||
// P50: 50% of requests are slower than this speed (median)
|
||||
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
|
||||
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
|
||||
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
|
||||
|
||||
// Create histogram data
|
||||
const min = Math.min(...tokensPerSecond);
|
||||
const max = Math.max(...tokensPerSecond);
|
||||
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5))); // Adaptive bin count
|
||||
const binSize = (max - min) / binCount;
|
||||
|
||||
const bins = Array(binCount).fill(0);
|
||||
tokensPerSecond.forEach((value) => {
|
||||
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
|
||||
bins[binIndex]++;
|
||||
});
|
||||
|
||||
const histogramData = {
|
||||
bins,
|
||||
min,
|
||||
max,
|
||||
binSize,
|
||||
p99,
|
||||
p95,
|
||||
p50,
|
||||
};
|
||||
|
||||
return [
|
||||
totalRequests,
|
||||
totalInputTokens,
|
||||
totalOutputTokens,
|
||||
{
|
||||
p99: p99.toFixed(2),
|
||||
p95: p95.toFixed(2),
|
||||
p50: p50.toFixed(2),
|
||||
},
|
||||
histogramData,
|
||||
];
|
||||
}, [metrics]);
|
||||
|
||||
const nf = new Intl.NumberFormat();
|
||||
|
||||
return (
|
||||
<div className="card">
|
||||
<div className="rounded-lg overflow-hidden border border-card-border-inner">
|
||||
<table className="min-w-full divide-y divide-card-border-inner">
|
||||
<thead className="bg-secondary">
|
||||
<tr>
|
||||
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">
|
||||
Requests
|
||||
</th>
|
||||
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Processed
|
||||
</th>
|
||||
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Generated
|
||||
</th>
|
||||
<th className="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Token Stats (tokens/sec)
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
|
||||
<tbody className="bg-surface divide-y divide-card-border-inner">
|
||||
<tr className="hover:bg-secondary">
|
||||
<td className="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">{totalRequests}</td>
|
||||
|
||||
<td className="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm font-medium">{nf.format(totalInputTokens)}</span>
|
||||
<span className="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td className="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm font-medium">{nf.format(totalOutputTokens)}</span>
|
||||
<span className="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td className="px-4 py-4 border-l border-gray-200 dark:border-white/10">
|
||||
<div className="space-y-3">
|
||||
<div className="grid grid-cols-3 gap-2 items-center">
|
||||
<div className="text-center">
|
||||
<div className="text-xs text-gray-500 dark:text-gray-400">P50</div>
|
||||
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{tokenStats.p50}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="text-center">
|
||||
<div className="text-xs text-gray-500 dark:text-gray-400">P95</div>
|
||||
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{tokenStats.p95}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="text-center">
|
||||
<div className="text-xs text-gray-500 dark:text-gray-400">P99</div>
|
||||
<div className="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{tokenStats.p99}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{histogramData && <TokenHistogram data={histogramData} />}
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -15,6 +15,7 @@ export default defineConfig({
|
||||
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
||||
"/logs": "http://localhost:8080",
|
||||
"/upstream": "http://localhost:8080",
|
||||
"/unload": "http://localhost:8080",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user