Compare commits

...

25 Commits

Author SHA1 Message Date
Benson Wong fd3c28ffc5 Refactor Activity Page (#710)
- inference handles to store an activity record for all inference endpoints
- add path, status code, and content type to Activities page
- toggle on/off columns no Activities page 
- add configurable capture level for inference endpoints so large binary blobs are not stored in memory
- store captures in compressed binary format
2026-04-28 20:33:03 -07:00
Quentin Machu a846c4f18c config: remove hard cap on macro length (#718)
Remove macro value limit of 1024 characters
2026-04-28 13:32:54 -07:00
Marcus 5bae33a769 ui-svelte: default theme to user preferred color scheme (#712)
Simple, if not set is localStorage use whatever the user's preferred
color scheme is to start.
2026-04-27 06:44:22 -07:00
Benson Wong 8f4ff01f93 ui-svelte: make it easier to toggle panels in logs view 2026-04-26 22:12:43 -07:00
Benson Wong e8d4384cd2 ui-svelte: support reasoning and reasoning_content (#708)
Support `reasoning` v1/chat/completion delta that vLLM uses.
2026-04-26 13:11:48 -07:00
Benson Wong ce28485be2 ui-svelte: add prompt processing histogram (#705)
Activities page shows histograms for prompt processing and token generation times. 

Fix: #691
Fix: #703
2026-04-25 16:13:07 -07:00
Damir 3cd7837b1f fix: support architecture-specific download URLs in install script (#698)
Just a small fix to include proper llama-swap binary when building the
arm64 architecture.
2026-04-23 18:05:33 -07:00
Benson Wong 0b31ccacc1 ui-svelte: fix histogram calculation (#695)
- Fix the histogram calculation to use server provided generation
tokens/second.
- Move histogram to Activities page where it can exist with the rest of
the token metrics

Fixes #681
2026-04-22 23:42:39 -07:00
Bryan Gahagan 5938dbee8f Push unified docker images on scheduled runs (#694)
Fixes #693
2026-04-22 20:46:51 -07:00
Benson Wong 66639e83f7 proxy: replace fsnotify with stat-poll watcher and add SIGHUP reload (#685)
The fsnotify-based config watcher does not work reliably when the config
file is bind-mounted into a Docker container as an individual file, and
mishandles k8s ConfigMap projections (atomically swapped symlinks).
Replace it with a small os.Stat-polling watcher and add SIGHUP as an
explicit reload signal.

- new proxy/configwatcher package: 2s os.Stat poller, follows symlinks,
  fires on mtime/size change and on missing -> present transitions
- SIGHUP triggers reload unconditionally (works without --watch-config)
  via the same ConfigFileChangedEvent pipeline so the UI sees identical
  state transitions
- watcher goroutine now exits cleanly on shutdown via a context
- drop github.com/fsnotify/fsnotify dependency

fixes #682
2026-04-21 23:21:48 -07:00
Benson Wong 625b296720 docker/unified: add uv via pip install (#681)
Install uv after the cpp tool binaries are copied and before the
llama-swap binary, enabling `uv run` usage for Python-based inference
backends like vLLM.

- add python3-pip to runtime apt installs
- add `pip install uv --break-system-packages` after cpp installs

fixes #628

Co-authored-by: Claude <noreply@anthropic.com>
2026-04-20 20:55:51 -07:00
Benson Wong 231e62291c proxy: fix matrix race and process stop bug (#677)
- matrix.go change logic to consider any proxy.Process not in
StateStopped or StateShutdown
- process.StopImmediately, and Stop() which called it had a subtle bug
where it only handled state transitions from StateReady to
StateStopping. StateStarting -> StateStopping was ignored completely.

fix: #670
2026-04-20 00:21:11 -07:00
Benson Wong 57ac666598 .github/workflows: tweak push ghcr conditional (#676) 2026-04-19 13:56:26 -07:00
Benson Wong 69728301f5 .github/workflows: add toggle for pushing unified images to github (#672)
Add ability to dispatch (manually run) unified container builds in github without push to ghcr.io.
2026-04-19 10:10:48 -07:00
Benson Wong c176fa70f1 docker/unified: add spirv-headers to fix vulkan build (#669) 2026-04-18 12:18:10 -07:00
Benson Wong 5e3c646829 proxy: compress captures with zstd (#668)
The previous captures were saved uncompressed in memory. In agentic
workflows there can be many turns with each request containing the
previous context in the body with a lot of redundant data. Use zstd to
compress the request and response data before keeping a copy of memory.

Results: 

- Average Percentage Saved: 73.19%
- Average Compression Factor: ~6.77:1
2026-04-17 23:29:37 -07:00
Benson Wong c3f0d43e6e proxy: fix race conditions during swap (#667)
I pointed Opus 4.7 (high effort) at proxy.ProcessGroup to identify any
race conditions in the swapping code. It found a race condition where
there is a small window in the fast path for routing a request to a
loaded model. There is a very small window where:

- model M1 is loaded and ready for requests
- a request, R1, for M1 comes in 
- a request, R2, for M2 comes in almost immediately after
- R1 acquires the lock, sees M1 is loaded (fast path), releases the lock
`[race window]` and the request is ready to be forwarded
- the race window occurs between the release of the lock and the request
being forwarded
  - the lock is released so requests can be handled concurrently 
- R2 comes in within the `[race window]`, acquires the lock, triggers a
model swap to M2. stopping M1
- R1 is forwarded to a model that is unloaded or in the process of
shutting down creating an error response

In deployed systems the race window is very small and doesn't happen
often. However with #635 and PR #656 I though this deserved a bit more
attention. It is not concluded that this race is the cause of #635 but
the race is likely to happen more often under sustained or high load.

AI Note: Opus 4.7 x-high effort took about an hour to write the original
patch. With the pattern discovered the fix to matrix.go was very quick.
GLM 5.1 using the previous established patterns was able to easily write
the fix for ProcessGroup.StopProcesses().

Supersedes: #656
Updates: #277, #635
2026-04-17 21:23:17 -07:00
Benson Wong f6cf9f5844 proxy: Refactor tests (#660)
- use YAML for test configurations
- remove most uses of simple-responder, opting to use
process.testHandler

Fixes #655
2026-04-16 22:47:42 -07:00
Benson Wong 121fd93ad8 Makefile: restore linux arm64 targets
Fix #641
2026-04-14 22:05:39 -07:00
Benson Wong 17233e9278 docs: update configuration.md for matrix 2026-04-14 22:01:03 -07:00
Benson Wong 4866d16c3e README.md: update to use matrix instead of groups 2026-04-14 21:57:49 -07:00
Benson Wong 35193f82f1 proxy: add swap matrix with solver-based model swapping (#646)
Add a new swap matrix to supersede groups for running concurrent models.
The matrix uses a solver that picks the lowest cost evictions to make a
requested model available. This simple approach along with a very basic
DSL grammar can enable very complex swapping scenarios.

- add DSL parser for set expressions with & (AND), | (OR), (), +ref
- add MatrixConfig structs, validation, and topological sort for +ref
- add MatrixSolver with cost-minimizing swap decisions
- add Matrix runtime integrating solver with Process lifecycle
- integrate matrix into ProxyManager with if-branches at all endpoints
- update config.example.yaml and config-schema.json with matrix schema
- config enforces groups XOR matrix (cannot use both)

fixes #643
2026-04-14 21:55:30 -07:00
Benson Wong 40e39f7a86 ui-svelte: fix security issues (#649) 2026-04-12 16:21:31 -07:00
Benson Wong a9d840ffd7 proxy,proxy/config: restore timeouts to pre PR 619 (#648)
Reset the default ResponseHeader timeout to 0 (no timeout) which was set
to 60 seconds in PR #619.

Fixes #647
2026-04-11 20:42:13 -07:00
Benson Wong 7b2b82777f docker/unified: derive rootless image from root container (#644)
Build the root image once, then derive the rootless variant from it
using a small inline Dockerfile that adds the non-root user and chowns
the writable directories. This halves the number of CI jobs (4 → 2) and
eliminates the redundant full CUDA compilation for the rootless variant.

- remove RUN_UID build arg from build-image.sh
- derive rootless image inline after root build completes
- collapse variant matrix out of unified-docker.yml
- push both root and rootless tags in a single CI job

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 22:59:54 -07:00
59 changed files with 5927 additions and 1752 deletions
+46 -47
View File
@@ -2,69 +2,68 @@ name: Linux CI
on: on:
push: push:
branches: [ "main" ] branches: ["main"]
# only run when backend source changes # only run when backend source changes
# cmd/ is excluded because it contains utilities without tests # cmd/ is excluded because it contains utilities without tests
paths: paths:
- '**/*.go' - "**/*.go"
- '!cmd/**' - "!cmd/**"
- 'go.mod' - "go.mod"
- 'go.sum' - "go.sum"
- 'Makefile' - "Makefile"
- '.github/workflows/go-ci.yml' - ".github/workflows/go-ci.yml"
pull_request: pull_request:
branches: [ "main" ] branches: ["main"]
paths: paths:
- '**/*.go' - "**/*.go"
- '!cmd/**' - "!cmd/**"
- 'go.mod' - "go.mod"
- 'go.sum' - "go.sum"
- 'Makefile' - "Makefile"
- '.github/workflows/go-ci.yml' - ".github/workflows/go-ci.yml"
# Allows manual triggering of the workflow # Allows manual triggering of the workflow
workflow_dispatch: workflow_dispatch:
jobs: jobs:
run-tests: run-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v4
with: with:
go-version-file: go.mod go-version-file: go.mod
# Only run in this linux based runner # Only run in this linux based runner
- name: Check Formatting - name: Check Formatting
run: | run: |
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
gofmt -l . | grep -v 'event/.*_test.go' gofmt -l . | grep -v 'event/.*_test.go'
exit 1 exit 1
fi fi
# cache simple-responder to save the build time # cache simple-responder to save the build time
- name: Restore Simple Responder - name: Restore Simple Responder
id: restore-simple-responder id: restore-simple-responder
uses: actions/cache/restore@v4 uses: actions/cache/restore@v4
with: with:
path: ./build path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }} key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping # necessary for testing proxy/Process swapping
- name: Create simple-responder - name: Create simple-responder
run: make simple-responder run: make simple-responder
- name: Save Simple Responder - name: Save Simple Responder
# nothing new to save ... skip this step # nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true' if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder id: save-simple-responder
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:
path: ./build path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }} key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
- name: Test all - name: Test all
run: make test-all run: make test-all
+2 -11
View File
@@ -19,9 +19,6 @@ jobs:
run-tests: run-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults:
run:
working-directory: ui-svelte
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -32,11 +29,5 @@ jobs:
cache: 'npm' cache: 'npm'
cache-dependency-path: ui-svelte/package-lock.json cache-dependency-path: ui-svelte/package-lock.json
- name: Install dependencies - name: Run UI tests
run: npm ci run: make test-ui
- name: Type check
run: npm run check
- name: Run tests
run: npm test
+18 -15
View File
@@ -36,6 +36,11 @@ on:
type: boolean type: boolean
required: false required: false
default: true default: true
push_to_ghcr:
description: "Push images to ghcr.io"
type: boolean
required: false
default: true
permissions: permissions:
contents: read contents: read
@@ -68,13 +73,6 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
backend: ${{ fromJSON(needs.setup.outputs.matrix) }} backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
variant:
- name: root
uid: "0"
suffix: ""
- name: rootless
uid: "10001"
suffix: "-rootless"
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -106,15 +104,14 @@ jobs:
username: ${{ github.actor }} username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Build unified Docker image (${{ matrix.backend }}, ${{ matrix.variant.name }}) - name: Build unified Docker image (${{ matrix.backend }})
env: env:
LLAMA_REF: ${{ inputs.llama_cpp_ref || 'master' }} LLAMA_REF: ${{ inputs.llama_cpp_ref || 'master' }}
WHISPER_REF: ${{ inputs.whisper_ref || 'master' }} WHISPER_REF: ${{ inputs.whisper_ref || 'master' }}
SD_REF: ${{ inputs.sd_ref || 'master' }} SD_REF: ${{ inputs.sd_ref || 'master' }}
IK_LLAMA_REF: ${{ inputs.ik_llama_ref || 'main' }} IK_LLAMA_REF: ${{ inputs.ik_llama_ref || 'main' }}
LS_VERSION: ${{ inputs.llama_swap_version || 'main' }} LS_VERSION: ${{ inputs.llama_swap_version || 'main' }}
RUN_UID: ${{ matrix.variant.uid }} DOCKER_IMAGE_TAG: ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}
DOCKER_IMAGE_TAG: ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}${{ matrix.variant.suffix }}
# When running under act, use the local builder that has warm ccache. # When running under act, use the local builder that has warm ccache.
# On GitHub Actions, BUILDX_BUILDER is unset so docker uses the builder # On GitHub Actions, BUILDX_BUILDER is unset so docker uses the builder
# created by setup-buildx-action above. # created by setup-buildx-action above.
@@ -124,10 +121,16 @@ jobs:
docker/unified/build-image.sh --${{ matrix.backend }} docker/unified/build-image.sh --${{ matrix.backend }}
- name: Push to GitHub Container Registry - name: Push to GitHub Container Registry
if: ${{ !env.ACT }} if: ${{ !env.ACT && (github.event_name == 'schedule' || inputs.push_to_ghcr == true) }}
run: | run: |
TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}${{ matrix.variant.suffix }}" BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
docker push "${TAG}"
DATE_TAG=$(date -u +%Y-%m-%d) DATE_TAG=$(date -u +%Y-%m-%d)
docker tag "${TAG}" "${TAG}-${DATE_TAG}"
docker push "${TAG}-${DATE_TAG}" docker push "${BASE_TAG}"
docker tag "${BASE_TAG}" "${BASE_TAG}-${DATE_TAG}"
docker push "${BASE_TAG}-${DATE_TAG}"
ROOTLESS_TAG="${BASE_TAG}-rootless"
docker push "${ROOTLESS_TAG}"
docker tag "${ROOTLESS_TAG}" "${ROOTLESS_TAG}-${DATE_TAG}"
docker push "${ROOTLESS_TAG}-${DATE_TAG}"
+1
View File
@@ -24,6 +24,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`. - Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory - Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests. - Use `make test-all` before completing work. This includes long running concurrency tests.
- Use `make test-ui` after making changes to the UI in ui-svelte/
### Commit message example format: ### Commit message example format:
+13 -4
View File
@@ -48,10 +48,15 @@ mac: ui
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64 GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
# Build Linux binary # Build Linux binary
linux: ui linux: linux-arm64 linux-amd64
@echo "Building Linux binary..."
linux-amd64: ui
@echo "Building Linux AMD64 binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64 GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
#GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
linux-arm64: ui
@echo "Building Linux ARM64 binary..."
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
# Build Windows binary # Build Windows binary
windows: ui windows: ui
@@ -92,5 +97,9 @@ wol-proxy: $(BUILD_DIR)
@echo "Building wol-proxy" @echo "Building wol-proxy"
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
test-ui:
cd ui-svelte && npm ci && npm run check && npm test
# Phony targets # Phony targets
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy .PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev test-ui wol-proxy
.PHONE: linux linux-arm64 linux-amd64
+7 -9
View File
@@ -5,7 +5,7 @@
# llama-swap # llama-swap
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows. Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file. Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
@@ -45,14 +45,14 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- `/health` - just returns "OK" - `/health` - just returns "OK"
- ✅ API Key support - define keys to restrict access to API endpoints - ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable - ✅ Customizable
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643))
- Automatic unloading of models after timeout by setting a `ttl` - Automatic unloading of models after timeout by setting a `ttl`
- Reliable Docker and Podman support using `cmd` and `cmdStop` together - Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235)) - Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
### Web UI ### Web UI
llama-swap includes a real time web interface with a playground for testing out all sorts of local models: llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" /> <img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
@@ -64,16 +64,14 @@ Inspect request and responses:
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" /> <img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
Manually load and unload models: Manually load and unload models:
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" /> <img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
Real time log streaming:
Real time log streaming:
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" /> <img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
## Installation ## Installation
llama-swap can be installed in multiple ways llama-swap can be installed in multiple ways
@@ -182,7 +180,7 @@ That's all you need to get started:
Almost all configuration settings are optional and can be added one step at a time: Almost all configuration settings are optional and can be added one step at a time:
- Advanced features - Advanced features
- `groups` to run multiple models at once - `matrix` to run concurrent models with a custom swap logic DSL
- `hooks` to run things on startup - `hooks` to run things on startup
- `macros` reusable snippets - `macros` reusable snippets
- Model customization - Model customization
@@ -200,7 +198,7 @@ See the [configuration documentation](docs/configuration.md) for all options.
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly. When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used. In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, using a `matrix` allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## Reverse Proxy Configuration (nginx) ## Reverse Proxy Configuration (nginx)
+183
View File
@@ -0,0 +1,183 @@
# Improve Testability (#655)
## Current Pain Points
1. **Tests bypass config loading** - ~80% of tests build `config.Config` structs directly, skipping YAML parsing, env var substitution, macro expansion, and `${PORT}` assignment. Config bugs in those paths go untested.
2. **simple-responder is everywhere** - Every proxy/routing test launches a real subprocess, waits for health checks (~healthCheckTimeout: 15), and manages process lifecycle just to test HTTP routing. Most of that overhead is wasted.
3. **Port counter is fragile** - A global `nextTestPort` counter starting at 12000 with a mutex. Parallel tests or leftover processes can collide.
## Stages
### Stage 1: YAML-based test config helper
**Goal:** Tests go through the real `LoadConfigFromReader` path instead of hand-building structs.
**Effort:** Low | **Impact:** Config bugs caught earlier | **Risk:** None
Create a test helper in `proxy/helpers_test.go`:
```go
// testConfigFromYAML substitutes simple-responder paths and loads through
// the real config pipeline (env vars, macros, port assignment, etc.)
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
t.Helper()
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
require.NoError(t, err)
return cfg
}
```
Tests would then look like:
```go
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
model2:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model2
`)
proxy := New(config)
// ... same assertions
}
```
**Why this stage first:** Zero production code changes. Pure test-side refactoring. Can be done incrementally - migrate tests one at a time. Each migrated test now validates the full config pipeline.
**Scope:** ~20-30 tests in `proxymanager_test.go`, `processgroup_test.go`, `peerproxy_test.go`.
### Stage 2: Injected test handler (eliminate simple-responder for routing tests)
**Goal:** Replace simple-responder subprocess launches with an injected `http.Handler` for tests that don't specifically test process lifecycle.
**Effort:** Medium | **Impact:** 10-100x faster routing tests | **Risk:** Low (additive, no existing code broken)
Add a `testHandler http.Handler` field to `Process`. When set, `ProxyRequest` delegates directly to this handler instead of going through the reverse proxy. No subprocess, no health checks, no TCP roundtrip.
**2a. Add testHandler to Process:**
```go
// In Process struct (process.go):
testHandler http.Handler // set only in tests; bypasses subprocess and reverse proxy
```
In `Process.Start()`, skip subprocess + health check when handler is set:
```go
func (p *Process) start() error {
if p.testHandler != nil {
p.setState(StateReady)
return nil
}
// existing subprocess logic...
}
```
In `Process.ProxyRequest()`, delegate directly to the handler:
```go
// Before the reverseProxy.ServeHTTP call:
if p.testHandler != nil {
p.testHandler.ServeHTTP(w, r)
return
}
```
**2b. Test helper to create the handler:**
```go
// newTestHandler returns an http.Handler that mimics llama.cpp's API
// (same endpoints as simple-responder).
func newTestHandler(respond string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { ... })
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { ... })
// ... other endpoints
return mux
}
```
Tests for routing/auth/CORS/streaming then become:
```go
func TestProxyManager_AuthRequired(t *testing.T) {
handler := newTestHandler("model1")
config := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
requiredAPIKeys: [test-key]
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
`)
pm := NewProxyManager(config)
// inject handler — skips subprocess, health check, port allocation
pm.processGroups["model1"].process.testHandler = handler
}
```
**Why this matters:** The handler is called directly in-process. No subprocess spawn, no health check timeout, no port allocation, no TCP roundtrip, no reverse proxy overhead. Routing tests go from ~100ms each (process startup + health check) to ~1ms. Unlike an `httptest.Server` approach, there are zero network hops.
**Why not blank-cmd + proxy URL:** A blank `cmd` with a `proxy` field pointing at `httptest.Server` still requires a real TCP roundtrip through the reverse proxy and introduces "external process" semantics to the config schema. Injecting the handler directly keeps it purely a test concern with no config changes.
**Scope:** Most tests in `proxymanager_test.go` (auth, CORS, model listing, streaming, peer proxy), `peerproxy_test.go`, `metrics_monitor_test.go`.
### Stage 3: Migrate tests incrementally
**Goal:** Convert existing tests to use the Stage 1 + Stage 2 helpers.
**Effort:** Medium | **Impact:** Cleaner, more reliable tests | **Risk:** None
Priority order:
1. `proxymanager_test.go` routing tests (highest count, most repetition)
2. `peerproxy_test.go` (straightforward, all HTTP routing)
3. `metrics_monitor_test.go` (capture logic doesn't need real processes)
4. `processgroup_test.go` swap tests (keep simple-responder for actual swap lifecycle tests)
Tests that **must keep simple-responder:**
- Process lifecycle: start/stop, SIGKILL, SIGTERM, TTL expiry, health check failures, failed start counting
- ProcessGroup swap concurrency (the port-collision test in `TestProcessGroup_ProxyRequestSwapIsTrueParallel`)
**Scope:** ~60-70% of tests can drop simple-responder.
### Stage 4 (optional): Process interface for ProcessGroup
**Goal:** Enable pure unit tests of ProcessGroup's swap/exclusive/concurrency logic without any HTTP server at all.
**Effort:** High | **Impact:** Pure unit tests possible | **Risk:** Medium (refactor core code)
```go
type ProcessController interface {
Start() error
Stop(StopStrategy)
ProxyRequest(http.ResponseWriter, *http.Request) error
CurrentState() ProcessState
ID() string
SetState(ProcessState) // for test setup
}
```
This requires:
- Extracting the interface
- A `MockProcess` implementation
- Refactoring `ProcessGroup` to use the interface instead of `*Process`
**Recommendation:** Only do this if ProcessGroup grows significantly more complex. Stages 1-3 give 80% of the benefit for 20% of the effort.
## Effort/Impact Summary
| Stage | Effort | Impact | Risk |
|-------|--------|--------|------|
| 1. YAML config helper | Low | Config bugs caught earlier | None |
| 2. Injected test handler | Medium | 10-100x faster routing tests | Low |
| 3. Migrate tests | Medium | Cleaner, more reliable tests | None |
| 4. Process interface | High | Pure unit tests possible | Medium |
**Recommended approach:** Do stages 1-3 in order. Each stage is independently valuable and can ship on its own. Stage 4 is deferred unless there's a specific need.
+84 -12
View File
@@ -47,31 +47,37 @@
"type": "integer", "type": "integer",
"minimum": 0, "minimum": 0,
"default": 30, "default": 30,
"description": "TCP connection timeout in seconds. Set to 0 to disable (not recommended)." "description": "TCP connection timeout in seconds. Set to 0 to disable."
},
"keepalive": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP keepalive timeout in seconds. Set to 0 to disable."
}, },
"responseHeader": { "responseHeader": {
"type": "integer", "type": "integer",
"minimum": 0, "minimum": 0,
"default": 60, "default": 0,
"description": "Time to wait for response headers in seconds. Set to 0 to disable (not recommended)." "description": "Time to wait for response headers in seconds. Set to 0 to disable."
}, },
"tlsHandshake": { "tlsHandshake": {
"type": "integer", "type": "integer",
"minimum": 0, "minimum": 0,
"default": 10, "default": 10,
"description": "TLS handshake timeout in seconds. Set to 0 to disable (not recommended)." "description": "TLS handshake timeout in seconds. Set to 0 to disable."
}, },
"expectContinue": { "expectContinue": {
"type": "integer", "type": "integer",
"minimum": 0, "minimum": 0,
"default": 1, "default": 1,
"description": "Expect-Continue timeout in seconds. Set to 0 to disable (not recommended)." "description": "Expect-Continue timeout in seconds. Set to 0 to disable."
}, },
"idleConn": { "idleConn": {
"type": "integer", "type": "integer",
"minimum": 0, "minimum": 0,
"default": 90, "default": 90,
"description": "Idle connection timeout in seconds. Set to 0 to disable (not recommended)." "description": "Idle connection timeout in seconds. Set to 0 to disable."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@@ -319,6 +325,44 @@
}, },
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent." "description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
}, },
"matrix": {
"type": "object",
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
"required": [
"vars",
"sets"
],
"properties": {
"vars": {
"type": "object",
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
"minProperties": 1,
"additionalProperties": {
"type": "string"
},
"propertyNames": {
"pattern": "^[a-zA-Z0-9]{1,8}$"
}
},
"evict_costs": {
"type": "object",
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
"additionalProperties": {
"type": "integer",
"minimum": 1
}
},
"sets": {
"type": "object",
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
"minProperties": 1,
"additionalProperties": {
"type": "string"
}
}
},
"additionalProperties": false
},
"hooks": { "hooks": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -413,25 +457,31 @@
"properties": { "properties": {
"connect": { "connect": {
"type": "integer", "type": "integer",
"minimum": 1, "minimum": 0,
"default": 30, "default": 30,
"description": "TCP connection timeout in seconds." "description": "TCP connection timeout in seconds."
}, },
"keepalive": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP keepalive connection timeout in seconds."
},
"responseHeader": { "responseHeader": {
"type": "integer", "type": "integer",
"minimum": 1, "minimum": 0,
"default": 60, "default": 0,
"description": "Time to wait for response headers in seconds." "description": "Time to wait for response headers in seconds."
}, },
"tlsHandshake": { "tlsHandshake": {
"type": "integer", "type": "integer",
"minimum": 1, "minimum": 0,
"default": 10, "default": 10,
"description": "TLS handshake timeout in seconds." "description": "TLS handshake timeout in seconds."
}, },
"idleConn": { "idleConn": {
"type": "integer", "type": "integer",
"minimum": 1, "minimum": 0,
"default": 90, "default": 90,
"description": "Idle connection timeout in seconds." "description": "Idle connection timeout in seconds."
} }
@@ -444,5 +494,27 @@
"default": {}, "default": {},
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap." "description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
} }
} },
"allOf": [
{
"if": {
"required": ["groups"]
},
"then": {
"not": {
"required": ["matrix"]
}
}
},
{
"if": {
"required": ["matrix"]
},
"then": {
"not": {
"required": ["groups"]
}
}
}
]
} }
+78 -61
View File
@@ -287,14 +287,15 @@ models:
# timeouts: configure proxy connection timeouts for this model # timeouts: configure proxy connection timeouts for this model
# - optional, defaults shown below # - optional, defaults shown below
# - useful for models running on slower hardware that need longer timeouts # - useful for models running on slower hardware that need longer timeouts
# - connect: TCP connection timeout in seconds # - connect: TCP dial connection timeout in seconds, default: 30 seconds
# - responseHeader: time to wait for response headers in seconds # - keepalive: TCP connection keepalive timeout, default: 30 seconds
# (increasing this helps avoid 502 errors on slow hardware) # - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
# - tlsHandshake: TLS handshake timeout in seconds # - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
# - idleConn: idle connection timeout in seconds # - idleConn: idle connection timeout in seconds, default: 90 seconds
# - set any value to 0 to disable that timeout (not recommended) # - set any value to 0 to disable that timeout (not recommended)
timeouts: timeouts:
connect: 30 connect: 30
keepalive: 0
responseHeader: 60 responseHeader: 60
tlsHandshake: 10 tlsHandshake: 10
idleConn: 90 idleConn: 90
@@ -330,68 +331,83 @@ models:
# - processes have 5 seconds to shutdown until forceful termination is attempted # - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID} cmdStop: docker stop ${MODEL_ID}
# groups: a dictionary of group settings # =============================================================================
# - optional, default: empty dictionary # matrix: run concurrent models with a solver-based swap DSL
# - provides advanced controls over model swapping behaviour # =============================================================================
# - using groups some models can be kept loaded indefinitely, while others are swapped out
# - model IDs must be defined in the Models section
# - a model can only be a member of one group
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
# - see issue #109 for details
# #
# NOTE: the example below uses model names that are not defined above for demonstration purposes # Note:
groups: # A config must use either a matrix or legacy groups, not both. A configuration error
# group1 works the same as the default behaviour of llama-swap where only one model is allowed # will occur if both are defined. Configuration examples for legacy Groups can be found:
# to run a time across the whole llama-swap instance # https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
"group1": #
# swap: controls the model swapping behaviour in within the group # The matrix declares valid combinations of models that can run concurrently.
# - optional, default: true # When a model is requested, the solver finds the cheapest way to make it
# - true : only one model is allowed to run at a time # available by evicting as few (and least costly) running models as possible.
# - false: all models can run together, no swapping #
swap: true # Solver behavior:
# 1. Request arrives for model X
# 2. If X is already running, forward immediately. Done.
# 3. Find all sets containing X
# 4. For each candidate set, compute cost: sum of evict_costs for
# every running model NOT in that set
# 5. Pick lowest cost candidate. Ties broken by definition order.
# 6. Evict what needs to stop. Start X. Forward request.
#
# Subset semantics: a set [a, b, c] means any subset is valid.
# Only the requested model is started — others are not preloaded.
#
# A model not appearing in any set can only run alone.
#
matrix:
# vars: short names for models (alphanumeric, 1-8 chars)
# - required for sets and evict_costs settings
# - each entry is a short name to a real model ID. Do not use an alias
# - used to keep set DSL logic short and easier to read
# - sets and evict_costs only use identifiers defined in vars
vars:
g: gemma-model
q: qwen-model
m: mistral-model
v: voxtral-model
e: reranker-model
L: llama-70B
sd: stable-diffusion
# exclusive: controls how the group affects other groups # evict_costs: relative cost of losing a running model (default: 1)
# - optional, default: true evict_costs:
# - true: causes all other groups to unload when this group runs a model v: 50 # vllm backend, slow cold start
# - false: does not affect other groups L: 30 # 70B weights, slow to load
exclusive: true
# members references the models defined above # sets: named sets of concurrent model combinations
# required # Values are DSL strings with operators:
members: # & AND (models run together)
- "llama" # | OR (alternatives)
- "qwen-unlisted" # () grouping
# +ref inline another set's expression
#
# Expansion examples:
# "L" → [L]
# "a & b" → [a, b]
# "a | b" → [a], [b]
# "(a | b) & c" → [a, c], [b, c]
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
# "+llms & v" → expands llms inline, then applies & v
sets:
# LLM + TTS: switching between g/q/m won't evict v
# expands to: [g,v], [q,v], [m,v]
standard: "(g | q | m) & v"
# Example: # LLM + TTS + reranker
# - in group2 all models can run at the same time # expands to: [g,v,e], [q,v,e]
# - when a different group is loaded it causes all running models in this group to unload with_rerank: "(g | q) & v & e"
"group2":
swap: false
# exclusive: false does not unload other groups when a model in group2 is requested # LLM + image generation, no TTS
# - the models in group2 will be loaded but will not unload any other groups # expands to: [g,sd], [q,sd]
exclusive: false creative: "(g | q) & sd"
members:
- "docker-llama"
- "modelA"
- "modelB"
# Example: # 70B model uses all GPUs, can only run alone
# - a persistent group, prevents other groups from unloading it # expands to: [L]
"forever": full: "L"
# persistent: prevents over groups from unloading the models in this group
# - optional, default: false
# - does not affect individual model behaviour
persistent: true
# set swap/exclusive to false to prevent swapping inside the group
# and the unloading of other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
# hooks: a dictionary of event triggers and actions # hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary # - optional, default: empty dictionary
@@ -447,6 +463,7 @@ peers:
# - set any value to 0 to disable that timeout (not recommended) # - set any value to 0 to disable that timeout (not recommended)
timeouts: timeouts:
connect: 30 connect: 30
keepalive: 30
responseHeader: 60 responseHeader: 60
tlsHandshake: 10 tlsHandshake: 10
idleConn: 90 idleConn: 90
+5 -1
View File
@@ -42,6 +42,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git python3 python3-pip libssl-dev \ build-essential cmake git python3 python3-pip libssl-dev \
curl ca-certificates ccache make wget software-properties-common \ curl ca-certificates ccache make wget software-properties-common \
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \ libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
spirv-headers \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
WORKDIR /build WORKDIR /build
@@ -148,7 +149,7 @@ ARG IK_LLAMA_COMMIT_HASH=unknown
ARG RUN_UID=0 ARG RUN_UID=0
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
python3-numpy python3-sentencepiece \ python3-numpy python3-sentencepiece python3-pip \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Create non-root user when RUN_UID != 0 # Create non-root user when RUN_UID != 0
@@ -179,6 +180,9 @@ COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
# Copy ik-llama-server (CUDA only; empty copy for vulkan) # Copy ik-llama-server (CUDA only; empty copy for vulkan)
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/ COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
# Install uv
RUN pip install uv --break-system-packages
# Copy llama-swap binary # Copy llama-swap binary
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/ COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
COPY --from=llama-swap-download /install/llama-swap-version /tmp/ COPY --from=llama-swap-download /install/llama-swap-version /tmp/
+22 -2
View File
@@ -201,7 +201,6 @@ BUILD_ARGS=(
--build-arg "SD_COMMIT_HASH=${SD_HASH}" --build-arg "SD_COMMIT_HASH=${SD_HASH}"
--build-arg "IK_LLAMA_COMMIT_HASH=${IK_LLAMA_HASH}" --build-arg "IK_LLAMA_COMMIT_HASH=${IK_LLAMA_HASH}"
--build-arg "LS_VERSION=${LS_HASH}" --build-arg "LS_VERSION=${LS_HASH}"
--build-arg "RUN_UID=${RUN_UID:-0}"
-t "${DOCKER_IMAGE_TAG}" -t "${DOCKER_IMAGE_TAG}"
-f "${SCRIPT_DIR}/Dockerfile" -f "${SCRIPT_DIR}/Dockerfile"
) )
@@ -255,12 +254,33 @@ if [[ "$BACKEND" == "cuda" ]]; then
fi fi
echo "All expected binaries verified: ${VERIFIED_LIST}" echo "All expected binaries verified: ${VERIFIED_LIST}"
echo ""
echo "=========================================="
echo "Building rootless image..."
echo "=========================================="
echo ""
ROOTLESS_TAG="${DOCKER_IMAGE_TAG}-rootless"
docker buildx build --load -t "${ROOTLESS_TAG}" - <<EOF
FROM ${DOCKER_IMAGE_TAG}
USER root
RUN groupadd --system --gid 10001 llama-swap && \\
useradd --system --uid 10001 --gid 10001 \\
--home /app --shell /sbin/nologin llama-swap && \\
chown -R 10001:10001 /etc/llama-swap /models
USER 10001
EOF
echo "Rootless image built: ${ROOTLESS_TAG}"
echo "" echo ""
echo "==========================================" echo "=========================================="
echo "Build complete!" echo "Build complete!"
echo "==========================================" echo "=========================================="
echo "" echo ""
echo "Image tag: ${DOCKER_IMAGE_TAG}" echo "Image tags:"
echo " ${DOCKER_IMAGE_TAG}"
echo " ${ROOTLESS_TAG}"
echo "" echo ""
echo "Built with:" echo "Built with:"
echo " llama.cpp: ${LLAMA_HASH}" echo " llama.cpp: ${LLAMA_HASH}"
+10 -2
View File
@@ -38,8 +38,16 @@ if [ "$VERSION" = "latest" ]; then
echo "Latest version: ${VERSION}" echo "Latest version: ${VERSION}"
fi fi
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;;
*) echo "FATAL: Unsupported architecture: $ARCH" >&2; exit 1 ;;
esac
# Download and extract # Download and extract
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_amd64.tar.gz" URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_${ARCH}.tar.gz"
echo "=== Downloading llama-swap v${VERSION} ===" echo "=== Downloading llama-swap v${VERSION} ==="
echo "URL: $URL" echo "URL: $URL"
curl -fSL -o /tmp/llama-swap.tar.gz "$URL" curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
@@ -56,4 +64,4 @@ fi
echo "$VERSION" > /install/llama-swap-version echo "$VERSION" > /install/llama-swap-version
echo "=== llama-swap v${VERSION} installed ===" echo "=== llama-swap v${VERSION} installed ==="
ls -la /install/bin/llama-swap ls -la /install/bin/llama-swap
+199 -118
View File
@@ -22,7 +22,7 @@ models:
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
``` ```
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model. Useful if you want to run multiple models at the same time with the `groups` feature. With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model which is useful if you want to run multiple models at the same time with the `matrix` feature.
## Advanced control with `cmd` ## Advanced control with `cmd`
@@ -76,7 +76,7 @@ llama-swap supports many more features to customize how you want to manage your
| --------- | ---------------------------------------------- | | --------- | ---------------------------------------------- |
| `ttl` | automatic unloading of models after a timeout | | `ttl` | automatic unloading of models after a timeout |
| `macros` | reusable snippets to use in configurations | | `macros` | reusable snippets to use in configurations |
| `groups` | run multiple models at a time | | `matrix` | run multiple models at a time |
| `hooks` | event driven functionality | | `hooks` | event driven functionality |
| `env` | define environment variables per model | | `env` | define environment variables per model |
| `aliases` | serve a model with different names | | `aliases` | serve a model with different names |
@@ -141,6 +141,11 @@ logToStdout: "proxy"
# - useful for limiting memory usage when processing large volumes of metrics # - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000 metricsMaxInMemory: 1000
# captureBuffer: how many MBs to allocate for storing request/response captures
# - optional, default: 10
# - set to 0 to disable
captureBuffer: 15
# startPort: sets the starting port number for the automatic ${PORT} macro. # startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800 # - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings # - the ${PORT} macro can be used in model.cmd and model.proxy settings
@@ -161,15 +166,10 @@ sendLoadingState: true
# all fields except for Id so chat UIs can use the alias equivalent to the original. # all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false includeAliasesInList: false
# apiKeys: require an API key when making requests to inference endpoints # globalTTL: the default TTL in seconds before unloading a model
# - optional, default: [] # - optional, default: 0 (never automatically unload)
# - when empty (the default) authorization will not be checked as llama-swap is default-allow # - must be >= 0
# - each key is a non-empty string globalTTL: 0
apiKeys:
- "sk-hunter2"
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
# macros: a dictionary of string substitutions # macros: a dictionary of string substitutions
# - optional, default: empty dictionary # - optional, default: empty dictionary
@@ -181,6 +181,9 @@ apiKeys:
# - macro names must not be a reserved name: PORT or MODEL_ID # - macro names must not be a reserved name: PORT or MODEL_ID
# - macro values can be numbers, bools, or strings # - macro values can be numbers, bools, or strings
# - macros can contain other macros, but they must be defined before they are used # - macros can contain other macros, but they must be defined before they are used
# - environment variables can be referenced with ${env.VAR_NAME} syntax
# - env macros are substituted first, before regular macros
# - if the env var is not set, config loading will fail with an error
macros: macros:
# Example of a multi-line macro # Example of a multi-line macro
"latest-llama": > "latest-llama": >
@@ -193,6 +196,24 @@ macros:
# but they must be previously declared. # but they must be previously declared.
"default_args": "--ctx-size ${default_ctx}" "default_args": "--ctx-size ${default_ctx}"
# Example of environment variable macros
# - ${env.VAR_NAME} pulls the value from the system environment
# - useful for paths, secrets, or machine-specific configuration
"models_dir": "${env.HOME}/models"
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
# use environment variable macros to keep secrets out of the config
- "${env.API_KEY_1}"
- "${env.API_KEY_2}"
# models: a dictionary of model configurations # models: a dictionary of model configurations
# - required # - required
# - each key is the model's ID, used in API requests # - each key is the model's ID, used in API requests
@@ -201,7 +222,7 @@ macros:
# - below are examples of the all the settings a model can have # - below are examples of the all the settings a model can have
models: models:
# keys are the model names used in API requests # keys are the model names used in API requests
"llama": "gpt-oss-120b":
# macros: a dictionary of string substitutions specific to this model # macros: a dictionary of string substitutions specific to this model
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - macros defined here override macros defined in the global macros section # - macros defined here override macros defined in the global macros section
@@ -218,7 +239,7 @@ models:
cmd: | cmd: |
# ${latest-llama} is a macro that is defined above # ${latest-llama} is a macro that is defined above
${latest-llama} ${latest-llama}
--model path/to/llama-8B-Q4_K_M.gguf --model path/to/gpt-oss-120B.gguf
--ctx-size ${default_ctx} --ctx-size ${default_ctx}
--temperature ${temp} --temperature ${temp}
@@ -226,13 +247,13 @@ models:
# - optional, default: empty string # - optional, default: empty string
# - if set, it will be used in the v1/models API response # - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record # - if not set, it will be omitted in the JSON model record
name: "llama 3.1 8B" name: "gpt-oss 120B"
# description: a description for the model # description: a description for the model
# - optional, default: empty string # - optional, default: empty string
# - if set, it will be used in the v1/models API response # - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record # - if not set, it will be omitted in the JSON model record
description: "A small but capable model used for quick testing" description: "A thinking model from OpenAI"
# env: define an array of environment variables to inject into cmd's environment # env: define an array of environment variables to inject into cmd's environment
# - optional, default: empty array # - optional, default: empty array
@@ -247,14 +268,6 @@ models:
# - if you use a custom port in cmd this *must* be set # - if you use a custom port in cmd this *must* be set
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# checkEndpoint: URL path to check if the server is ready # checkEndpoint: URL path to check if the server is ready
# - optional, default: /health # - optional, default: /health
# - endpoint is expected to return an HTTP 200 response # - endpoint is expected to return an HTTP 200 response
@@ -263,8 +276,10 @@ models:
checkEndpoint: /custom-endpoint checkEndpoint: /custom-endpoint
# ttl: automatically unload the model after ttl seconds # ttl: automatically unload the model after ttl seconds
# - optional, default: 0 # - optional, default: -1 (use global default)
# - ttl values must be a value greater than 0 # - ttl values must be a value greater than or equal to 0
# - a ttl of -1 will use the global TTL value as the default
# - a ttl of 0 will mean never unload
# - a value of 0 disables automatic unloading of the model # - a value of 0 disables automatic unloading of the model
ttl: 60 ttl: 60
@@ -272,11 +287,11 @@ models:
# - optional, default: "" # - optional, default: ""
# - useful for when the upstream server expects a specific model name that # - useful for when the upstream server expects a specific model name that
# is different from the model's ID # is different from the model's ID
useModelName: "qwen:qwq" useModelName: "openai/gpt-oss-120B"
# filters: a dictionary of filter settings # filters: a dictionary of filter settings
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - only stripParams is currently supported # - same capabilities as peer filters (stripParams, setParams)
filters: filters:
# stripParams: a comma separated list of parameters to remove from the request # stripParams: a comma separated list of parameters to remove from the request
# - optional, default: "" # - optional, default: ""
@@ -286,6 +301,43 @@ models:
# - recommended to stick to sampling parameters # - recommended to stick to sampling parameters
stripParams: "temperature, top_p, top_k" stripParams: "temperature, top_p, top_k"
# setParams: a dictionary of parameters to set/override in requests
# - optional, default: empty dictionary
# - useful for enforcing specific parameter values
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
# - always runs for the model
setParams:
# Example: enforce specific sampling parameters
temperature: 0.7
top_p: 0.9
# setParamsByID: a dictionary of parameters to set based the model ID
# - optional, default: empty dictionary
# - combine with aliases to create variant behaviour without reloading the model
# - parameters are set in the request body JSON
# - run after setParams so it will override any settings
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
# - model aliases will be automatically created for each key
setParamsByID:
"${MODEL_ID}":
chat_template_kwargs:
reasoning_effort: medium
"${MODEL_ID}:high":
chat_template_kwargs:
reasoning_effort: high
"${MODEL_ID}:low":
chat_template_kwargs:
reasoning_effort: low
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
# metadata: a dictionary of arbitrary values that are included in /v1/models # metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple # - while metadata can contains complex types it is recommended to keep it simple
@@ -319,33 +371,26 @@ models:
# - recommended to be omitted and the default used # - recommended to be omitted and the default used
concurrencyLimit: 0 concurrencyLimit: 0
# timeouts: configure proxy connection timeouts for this model
# - optional, defaults shown below
# - useful for models on slower hardware that need longer timeouts
# - increase responseHeader to avoid "timeout awaiting response headers" errors
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
# connect: TCP connection timeout in seconds
# - default: 30
connect: 30
# responseHeader: time to wait for response headers in seconds
# - default: 60
# - for slow image generation or large models, consider increasing to 300+ seconds
responseHeader: 60
# tlsHandshake: TLS handshake timeout in seconds
# - default: 10
tlsHandshake: 10
# idleConn: idle connection timeout in seconds
# - default: 90
idleConn: 90
# sendLoadingState: overrides the global sendLoadingState setting for this model # sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting) # - optional, default: undefined (use global setting)
sendLoadingState: false sendLoadingState: false
# timeouts: configure proxy connection timeouts for this model
# - optional, defaults shown below
# - useful for models running on slower hardware that need longer timeouts
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
# - idleConn: idle connection timeout in seconds, default: 90 seconds
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
keepalive: 0
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# Unlisted model example: # Unlisted model example:
"qwen-unlisted": "qwen-unlisted":
# unlisted: boolean, true or false # unlisted: boolean, true or false
@@ -377,68 +422,83 @@ models:
# - processes have 5 seconds to shutdown until forceful termination is attempted # - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID} cmdStop: docker stop ${MODEL_ID}
# groups: a dictionary of group settings # =============================================================================
# - optional, default: empty dictionary # matrix: run concurrent models with a solver-based swap DSL
# - provides advanced controls over model swapping behaviour # =============================================================================
# - using groups some models can be kept loaded indefinitely, while others are swapped out
# - model IDs must be defined in the Models section
# - a model can only be a member of one group
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
# - see issue #109 for details
# #
# NOTE: the example below uses model names that are not defined above for demonstration purposes # Note:
groups: # A config must use either a matrix or legacy groups, not both. A configuration error
# group1 works the same as the default behaviour of llama-swap where only one model is allowed # will occur if both are defined. Configuration examples for legacy Groups can be found:
# to run a time across the whole llama-swap instance # https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
"group1": #
# swap: controls the model swapping behaviour in within the group # The matrix declares valid combinations of models that can run concurrently.
# - optional, default: true # When a model is requested, the solver finds the cheapest way to make it
# - true : only one model is allowed to run at a time # available by evicting as few (and least costly) running models as possible.
# - false: all models can run together, no swapping #
swap: true # Solver behavior:
# 1. Request arrives for model X
# 2. If X is already running, forward immediately. Done.
# 3. Find all sets containing X
# 4. For each candidate set, compute cost: sum of evict_costs for
# every running model NOT in that set
# 5. Pick lowest cost candidate. Ties broken by definition order.
# 6. Evict what needs to stop. Start X. Forward request.
#
# Subset semantics: a set [a, b, c] means any subset is valid.
# Only the requested model is started — others are not preloaded.
#
# A model not appearing in any set can only run alone.
#
matrix:
# vars: short names for models (alphanumeric, 1-8 chars)
# - required for sets and evict_costs settings
# - each entry is a short name to a real model ID. Do not use an alias
# - used to keep set DSL logic short and easier to read
# - sets and evict_costs only use identifiers defined in vars
vars:
g: gemma-model
q: qwen-model
m: mistral-model
v: voxtral-model
e: reranker-model
L: llama-70B
sd: stable-diffusion
# exclusive: controls how the group affects other groups # evict_costs: relative cost of losing a running model (default: 1)
# - optional, default: true evict_costs:
# - true: causes all other groups to unload when this group runs a model v: 50 # vllm backend, slow cold start
# - false: does not affect other groups L: 30 # 70B weights, slow to load
exclusive: true
# members references the models defined above # sets: named sets of concurrent model combinations
# required # Values are DSL strings with operators:
members: # & AND (models run together)
- "llama" # | OR (alternatives)
- "qwen-unlisted" # () grouping
# +ref inline another set's expression
#
# Expansion examples:
# "L" → [L]
# "a & b" → [a, b]
# "a | b" → [a], [b]
# "(a | b) & c" → [a, c], [b, c]
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
# "+llms & v" → expands llms inline, then applies & v
sets:
# LLM + TTS: switching between g/q/m won't evict v
# expands to: [g,v], [q,v], [m,v]
standard: "(g | q | m) & v"
# Example: # LLM + TTS + reranker
# - in group2 all models can run at the same time # expands to: [g,v,e], [q,v,e]
# - when a different group is loaded it causes all running models in this group to unload with_rerank: "(g | q) & v & e"
"group2":
swap: false
# exclusive: false does not unload other groups when a model in group2 is requested # LLM + image generation, no TTS
# - the models in group2 will be loaded but will not unload any other groups # expands to: [g,sd], [q,sd]
exclusive: false creative: "(g | q) & sd"
members:
- "docker-llama"
- "modelA"
- "modelB"
# Example: # 70B model uses all GPUs, can only run alone
# - a persistent group, prevents other groups from unloading it # expands to: [L]
"forever": full: "L"
# persistent: prevents over groups from unloading the models in this group
# - optional, default: false
# - does not affect individual model behaviour
persistent: true
# set swap/exclusive to false to prevent swapping inside the group
# and the unloading of other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
# hooks: a dictionary of event triggers and actions # hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary # - optional, default: empty dictionary
@@ -467,17 +527,6 @@ peers:
# - required # - required
# - requested path to llama-swap will be appended to the end of the proxy value # - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23 proxy: http://192.168.1.23
# timeouts: configure proxy connection timeouts for this peer
# - optional, defaults shown below
# - useful when the peer runs on slower hardware
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# models: a list of models served by the peer # models: a list of models served by the peer
# - required # - required
models: models:
@@ -490,7 +539,8 @@ peers:
# - optional, default: "" # - optional, default: ""
# - if blank, no key will be added to the request # - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key> # - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
apiKey: sk-your-openrouter-key # - can be a string or a macro
apiKey: ${env.OPENROUTER_API_KEY}
models: models:
- meta-llama/llama-3.1-8b-instruct - meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507 - qwen/qwen3-235b-a22b-2507
@@ -498,4 +548,35 @@ peers:
- z-ai/glm-4.7 - z-ai/glm-4.7
- moonshotai/kimi-k2-0905 - moonshotai/kimi-k2-0905
- minimax/minimax-m2.1 - minimax/minimax-m2.1
# timeouts: configure proxy connection timeouts for this peer
# - optional, defaults shown below
# - useful when the peer runs on slower hardware
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
keepalive: 30
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# filters: a dictionary of filter settings for peer requests
# - optional, default: empty dictionary
# - same capabilities as model filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for removing parameters that the peer doesn't support
# - the `model` parameter can never be removed
stripParams: "temperature, top_p"
# setParams: a dictionary of parameters to set/override in requests to this peer
# - optional, default: empty dictionary
# - useful for injecting provider-specific settings like data retention policies
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
setParams:
# Example: enforce zero-data-retention for OpenRouter
provider:
data_collection: "deny"
zdr: true
``` ```
+3 -1
View File
@@ -4,8 +4,9 @@ go 1.26.1
require ( require (
github.com/billziss-gh/golib v0.2.0 github.com/billziss-gh/golib v0.2.0
github.com/fsnotify/fsnotify v1.9.0 github.com/fxamacker/cbor/v2 v2.9.1
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/klauspost/compress v1.18.5
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
@@ -36,6 +37,7 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.45.0 // indirect golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect golang.org/x/net v0.47.0 // indirect
+6 -2
View File
@@ -11,8 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -34,6 +34,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
@@ -77,6 +79,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
+46 -40
View File
@@ -9,14 +9,15 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"runtime"
"syscall" "syscall"
"time" "time"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy" "github.com/mostlygeek/llama-swap/proxy"
"github.com/mostlygeek/llama-swap/proxy/config" "github.com/mostlygeek/llama-swap/proxy/config"
"github.com/mostlygeek/llama-swap/proxy/configwatcher"
) )
var ( var (
@@ -79,6 +80,17 @@ func main() {
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Reload signals (SIGHUP on POSIX, none on Windows — Windows does not
// deliver SIGHUP). Always wired up so `kill -HUP` works regardless of
// --watch-config.
reloadChan := make(chan os.Signal, 1)
if runtime.GOOS != "windows" {
signal.Notify(reloadChan, syscall.SIGHUP)
}
// Context that bounds the lifetime of background watcher goroutines.
watcherCtx, watcherCancel := context.WithCancel(context.Background())
// Create server with initial handler // Create server with initial handler
srv := &http.Server{ srv := &http.Server{
Addr: *listenStr, Addr: *listenStr,
@@ -121,52 +133,45 @@ func main() {
// load the initial proxy manager // load the initial proxy manager
reloadProxyManager() reloadProxyManager()
debouncedReload := debounce(time.Second, reloadProxyManager) debouncedReload := debounce(time.Second, reloadProxyManager)
if *watchConfig {
defer event.On(func(e proxy.ConfigFileChangedEvent) {
if e.ReloadingState == proxy.ReloadingStateStart {
debouncedReload()
}
})()
fmt.Println("Watching Configuration for changes") // Listen for ConfigFileChangedEvent unconditionally so SIGHUP and the
// poll-based watcher both feed the same debounced reload pipeline. The
// UI also listens for the matching ReloadingStateEnd emitted from
// reloadProxyManager.
defer event.On(func(e proxy.ConfigFileChangedEvent) {
if e.ReloadingState == proxy.ReloadingStateStart {
debouncedReload()
}
})()
// SIGHUP (or platform-equivalent) → reload. Back-to-back signals collapse
// to one reload via the debounce window, which is the desired behavior.
go func() {
for range reloadChan {
fmt.Println("Received reload signal, reloading configuration")
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
}
}()
if *watchConfig {
go func() { go func() {
absConfigPath, err := filepath.Abs(*configPath) absConfigPath, err := filepath.Abs(*configPath)
if err != nil { if err != nil {
fmt.Printf("Error getting absolute path for watching config file: %v\n", err) fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
return return
} }
watcher, err := fsnotify.NewWatcher() fmt.Println("Watching configuration for changes (poll-based, 2s interval)")
if err != nil { (&configwatcher.Watcher{
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err) Path: absConfigPath,
return Interval: configwatcher.DefaultInterval,
} OnChange: func() {
event.Emit(proxy.ConfigFileChangedEvent{
configDir := filepath.Dir(absConfigPath) ReloadingState: proxy.ReloadingStateStart,
err = watcher.Add(configDir) })
if err != nil { },
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err) }).Run(watcherCtx)
return
}
defer watcher.Close()
for {
select {
case changeEvent := <-watcher.Events:
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
// the change for k8s configmap
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
}
case err := <-watcher.Errors:
log.Printf("File watcher error: %v", err)
}
}
}() }()
} }
@@ -174,6 +179,7 @@ func main() {
go func() { go func() {
sig := <-sigChan sig := <-sigChan
fmt.Printf("Received signal %v, shutting down...\n", sig) fmt.Printf("Received signal %v, shutting down...\n", sig)
watcherCancel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
+102
View File
@@ -0,0 +1,102 @@
package cache
import (
"errors"
"sync"
)
var (
ErrExceedsMaxSize = errors.New("item exceeds maximum cache size")
ErrNotFound = errors.New("item not found")
)
type Cache struct {
mu sync.Mutex
items map[int][]byte
order []int
size int
maxSize int
}
func New(maxBytes int) *Cache {
return &Cache{
items: make(map[int][]byte),
order: make([]int, 0),
maxSize: maxBytes,
}
}
func (c *Cache) Add(id int, data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
dataSize := len(data)
if dataSize > c.maxSize {
return ErrExceedsMaxSize
}
// If key already exists, remove old entry from size and order
if old, exists := c.items[id]; exists {
c.size -= len(old)
c.removeOrder(id)
}
// Evict oldest (FIFO) until room available
for c.size+dataSize > c.maxSize && len(c.order) > 0 {
oldestID := c.order[0]
c.order = c.order[1:]
if evicted, exists := c.items[oldestID]; exists {
c.size -= len(evicted)
delete(c.items, oldestID)
}
}
c.items[id] = data
c.order = append(c.order, id)
c.size += dataSize
return nil
}
func (c *Cache) removeOrder(id int) {
for i, v := range c.order {
if v == id {
c.order = append(c.order[:i], c.order[i+1:]...)
return
}
}
}
func (c *Cache) Get(id int) ([]byte, error) {
c.mu.Lock()
defer c.mu.Unlock()
data, exists := c.items[id]
if !exists {
return nil, ErrNotFound
}
return data, nil
}
func (c *Cache) Has(id int) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, exists := c.items[id]
return exists
}
func (c *Cache) Size() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.size
}
func (c *Cache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.items = make(map[int][]byte)
c.order = c.order[:0]
c.size = 0
}
+130
View File
@@ -0,0 +1,130 @@
package cache
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCache_Add(t *testing.T) {
t.Run("adds and retrieves item", func(t *testing.T) {
c := New(1024)
data := []byte("hello")
require.NoError(t, c.Add(1, data))
got, err := c.Get(1)
require.NoError(t, err)
assert.Equal(t, data, got)
})
t.Run("returns error for oversized item", func(t *testing.T) {
c := New(10)
err := c.Add(1, make([]byte, 20))
assert.ErrorIs(t, err, ErrExceedsMaxSize)
})
t.Run("evicts oldest items to make room", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, make([]byte, 40)))
require.NoError(t, c.Add(2, make([]byte, 40)))
// Adding item 3 should evict item 1
require.NoError(t, c.Add(3, make([]byte, 40)))
assert.False(t, c.Has(1))
assert.True(t, c.Has(2))
assert.True(t, c.Has(3))
})
t.Run("overwrites existing key", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, []byte("old")))
require.NoError(t, c.Add(1, []byte("new")))
got, err := c.Get(1)
require.NoError(t, err)
assert.Equal(t, []byte("new"), got)
assert.Equal(t, 3, c.Size())
})
}
func TestCache_Get(t *testing.T) {
t.Run("returns ErrNotFound for missing key", func(t *testing.T) {
c := New(100)
_, err := c.Get(99)
assert.ErrorIs(t, err, ErrNotFound)
})
}
func TestCache_Has(t *testing.T) {
t.Run("returns true for existing key", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, []byte("data")))
assert.True(t, c.Has(1))
})
t.Run("returns false for missing key", func(t *testing.T) {
c := New(100)
assert.False(t, c.Has(1))
})
}
func TestCache_Size(t *testing.T) {
t.Run("tracks byte usage", func(t *testing.T) {
c := New(1000)
assert.Equal(t, 0, c.Size())
require.NoError(t, c.Add(1, make([]byte, 100)))
assert.Equal(t, 100, c.Size())
require.NoError(t, c.Add(2, make([]byte, 200)))
assert.Equal(t, 300, c.Size())
})
t.Run("updates on eviction", func(t *testing.T) {
c := New(150)
require.NoError(t, c.Add(1, make([]byte, 100)))
require.NoError(t, c.Add(2, make([]byte, 100)))
// Item 1 should be evicted, size = 100
assert.Equal(t, 100, c.Size())
})
}
func TestCache_Clear(t *testing.T) {
t.Run("removes all items and resets size", func(t *testing.T) {
c := New(1000)
require.NoError(t, c.Add(1, []byte("a")))
require.NoError(t, c.Add(2, []byte("b")))
c.Clear()
assert.Equal(t, 0, c.Size())
assert.False(t, c.Has(1))
assert.False(t, c.Has(2))
})
}
func TestCache_Concurrent(t *testing.T) {
t.Run("concurrent operations are safe", func(t *testing.T) {
c := New(10000)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 100; j++ {
key := id*100 + j
_ = c.Add(key, []byte("data"))
_, _ = c.Get(key)
_ = c.Has(key)
_ = c.Size()
}
}(i)
}
wg.Wait()
})
}
+32 -16
View File
@@ -129,6 +129,12 @@ type Config struct {
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */ Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// swap matrix: solver-based alternative to groups
Matrix *MatrixConfig `yaml:"matrix"`
// populated during validation when matrix is configured
ExpandedSets []ExpandedSet `yaml:"-"`
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint // for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros MacroList `yaml:"macros"` Macros MacroList `yaml:"macros"`
@@ -438,22 +444,35 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
config.Models[modelId] = modelConfig config.Models[modelId] = modelConfig
} }
config = AddDefaultGroupToConfig(config) // groups XOR matrix
if config.Matrix != nil && len(config.Groups) > 0 {
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
}
// Validate group members if config.Matrix != nil {
memberUsage := make(map[string]string) expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
for groupID, groupConfig := range config.Groups { if err != nil {
prevSet := make(map[string]bool) return Config{}, fmt.Errorf("matrix: %w", err)
for _, member := range groupConfig.Members { }
if _, found := prevSet[member]; found { config.ExpandedSets = expandedSets
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID) } else {
} config = AddDefaultGroupToConfig(config)
prevSet[member] = true
if existingGroup, exists := memberUsage[member]; exists { // Validate group members
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID) memberUsage := make(map[string]string)
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
} }
memberUsage[member] = groupID
} }
} }
@@ -627,9 +646,6 @@ func validateMacro(name string, value any) error {
// Validate that value is a scalar type // Validate that value is a scalar type
switch v := value.(type) { switch v := value.(type) {
case string: case string:
if len(v) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
// Check for self-reference // Check for self-reference
macroSlug := fmt.Sprintf("${%s}", name) macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(v, macroSlug) { if strings.Contains(v, macroSlug) {
+13 -28
View File
@@ -163,6 +163,15 @@ groups:
modelLoadingState := false modelLoadingState := false
defaultTimeout := TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
}
expected := Config{ expected := Config{
LogLevel: "info", LogLevel: "info",
LogTimeFormat: "", LogTimeFormat: "",
@@ -187,13 +196,7 @@ groups:
Name: "Model 1", Name: "Model 1",
Description: "This is model 1", Description: "This is model 1",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: TimeoutsConfig{ Timeouts: defaultTimeout,
Connect: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}, },
"model2": { "model2": {
Cmd: "path/to/server --arg1 one", Cmd: "path/to/server --arg1 one",
@@ -202,13 +205,7 @@ groups:
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: TimeoutsConfig{ Timeouts: defaultTimeout,
Connect: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}, },
"model3": { "model3": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -217,13 +214,7 @@ groups:
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: TimeoutsConfig{ Timeouts: defaultTimeout,
Connect: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}, },
"model4": { "model4": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -232,13 +223,7 @@ groups:
Aliases: []string{}, Aliases: []string{},
Env: []string{}, Env: []string{},
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: TimeoutsConfig{ Timeouts: defaultTimeout,
Connect: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}, },
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
+1 -1
View File
@@ -1475,7 +1475,7 @@ models:
// Default values should be set during unmarshaling // Default values should be set during unmarshaling
assert.Equal(t, 30, modelConfig.Timeouts.Connect) assert.Equal(t, 30, modelConfig.Timeouts.Connect)
assert.Equal(t, 60, modelConfig.Timeouts.ResponseHeader) assert.Equal(t, 0, modelConfig.Timeouts.ResponseHeader)
assert.Equal(t, 10, modelConfig.Timeouts.TLSHandshake) assert.Equal(t, 10, modelConfig.Timeouts.TLSHandshake)
assert.Equal(t, 1, modelConfig.Timeouts.ExpectContinue) assert.Equal(t, 1, modelConfig.Timeouts.ExpectContinue)
assert.Equal(t, 90, modelConfig.Timeouts.IdleConn) assert.Equal(t, 90, modelConfig.Timeouts.IdleConn)
+13 -28
View File
@@ -155,6 +155,15 @@ groups:
modelLoadingState := false modelLoadingState := false
defaultTimeout := TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
}
expected := Config{ expected := Config{
LogLevel: "info", LogLevel: "info",
LogTimeFormat: "", LogTimeFormat: "",
@@ -173,13 +182,7 @@ groups:
Env: []string{"VAR1=value1", "VAR2=value2"}, Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health", CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: TimeoutsConfig{ Timeouts: defaultTimeout,
Connect: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}, },
"model2": { "model2": {
Cmd: "path/to/server --arg1 one", Cmd: "path/to/server --arg1 one",
@@ -189,13 +192,7 @@ groups:
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: TimeoutsConfig{ Timeouts: defaultTimeout,
Connect: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}, },
"model3": { "model3": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -205,13 +202,7 @@ groups:
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: TimeoutsConfig{ Timeouts: defaultTimeout,
Connect: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}, },
"model4": { "model4": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -221,13 +212,7 @@ groups:
Aliases: []string{}, Aliases: []string{},
Env: []string{}, Env: []string{},
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: TimeoutsConfig{ Timeouts: defaultTimeout,
Connect: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}, },
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
+226
View File
@@ -0,0 +1,226 @@
package config
import (
"fmt"
"regexp"
"sort"
"gopkg.in/yaml.v3"
)
var varKeyPattern = regexp.MustCompile(`^[a-zA-Z0-9]{1,8}$`)
// MatrixConfig represents the swap matrix configuration block.
type MatrixConfig struct {
Var map[string]string `yaml:"vars"`
EvictCosts map[string]int `yaml:"evict_costs"`
Sets OrderedSets `yaml:"sets"`
}
// SetEntry is a single named set with its DSL expression.
type SetEntry struct {
Name string
DSL string
}
// OrderedSets preserves YAML definition order of sets (used for tie-breaking).
type OrderedSets []SetEntry
func (os *OrderedSets) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.MappingNode {
return fmt.Errorf("sets must be a mapping")
}
entries := make([]SetEntry, 0, len(value.Content)/2)
for i := 0; i < len(value.Content); i += 2 {
keyNode := value.Content[i]
valueNode := value.Content[i+1]
var name string
if err := keyNode.Decode(&name); err != nil {
return fmt.Errorf("failed to decode set name: %w", err)
}
var dsl string
if err := valueNode.Decode(&dsl); err != nil {
return fmt.Errorf("failed to decode DSL for set %q: %w", name, err)
}
entries = append(entries, SetEntry{Name: name, DSL: dsl})
}
*os = entries
return nil
}
// ExpandedSet is one valid combination of concurrent models (real model names).
type ExpandedSet struct {
SetName string
DSL string
Models []string // real model names, sorted
}
// ValidateMatrix validates the matrix config and returns all expanded sets.
func ValidateMatrix(matrix MatrixConfig, models map[string]ModelConfig) ([]ExpandedSet, error) {
if len(matrix.Sets) == 0 {
return nil, fmt.Errorf("matrix must define at least one set")
}
if len(matrix.Var) == 0 {
return nil, fmt.Errorf("matrix must define at least one var")
}
// Validate var entries
if matrix.Var != nil {
for id, modelName := range matrix.Var {
if !varKeyPattern.MatchString(id) {
return nil, fmt.Errorf("var key %q must be alphanumeric and 1-8 characters", id)
}
if _, exists := models[modelName]; !exists {
return nil, fmt.Errorf("var key %q references unknown model %q", id, modelName)
}
}
}
// Validate evict_costs
if matrix.EvictCosts != nil {
for key, cost := range matrix.EvictCosts {
if cost <= 0 {
return nil, fmt.Errorf("evict_cost for %q must be a positive integer, got %d", key, cost)
}
if _, ok := matrix.Var[key]; !ok {
return nil, fmt.Errorf("evict_costs: unknown var ID %q", key)
}
}
}
// Build dependency graph for +ref topological sort
setNames := make(map[string]bool)
for _, entry := range matrix.Sets {
setNames[entry.Name] = true
}
deps := make(map[string][]string) // setName -> set names it depends on
for _, entry := range matrix.Sets {
refs, err := extractRefs(entry.DSL)
if err != nil {
return nil, fmt.Errorf("set %q: %w", entry.Name, err)
}
for _, ref := range refs {
if !setNames[ref] {
return nil, fmt.Errorf("set %q references undefined set %q", entry.Name, ref)
}
}
deps[entry.Name] = refs
}
// Topological sort with cycle detection
order, err := topologicalSort(matrix.Sets, deps)
if err != nil {
return nil, err
}
// Expand sets in topological order
resolvedRefs := make(map[string][][]string) // set name -> expanded alias-level combos
var allExpanded []ExpandedSet
totalCombinations := 0
// Build ordered map for efficient lookup
setDSL := make(map[string]string)
for _, entry := range matrix.Sets {
setDSL[entry.Name] = entry.DSL
}
for _, name := range order {
dsl := setDSL[name]
combos, err := ParseAndExpandDSL(dsl, resolvedRefs)
if err != nil {
return nil, fmt.Errorf("set %q: %w", name, err)
}
resolvedRefs[name] = combos
// Resolve var IDs to real model names
for _, combo := range combos {
resolved := make([]string, len(combo))
for i, ident := range combo {
realName, ok := matrix.Var[ident]
if !ok {
return nil, fmt.Errorf("set %q: unknown var ID %q", name, ident)
}
resolved[i] = realName
}
sort.Strings(resolved)
allExpanded = append(allExpanded, ExpandedSet{
SetName: name,
DSL: dsl,
Models: resolved,
})
}
totalCombinations += len(combos)
if totalCombinations > maxDSLExpansions {
return nil, fmt.Errorf("total expanded combinations (%d) exceed limit of %d", totalCombinations, maxDSLExpansions)
}
}
return allExpanded, nil
}
// topologicalSort returns set names in dependency order.
// Returns an error if a cycle is detected.
func topologicalSort(sets OrderedSets, deps map[string][]string) ([]string, error) {
// States: 0 = unvisited, 1 = visiting, 2 = visited
state := make(map[string]int)
var order []string
var visit func(name string) error
visit = func(name string) error {
switch state[name] {
case 1:
return fmt.Errorf("circular reference detected involving set %q", name)
case 2:
return nil
}
state[name] = 1
for _, dep := range deps[name] {
if err := visit(dep); err != nil {
return err
}
}
state[name] = 2
order = append(order, name)
return nil
}
// Visit in definition order for deterministic output
for _, entry := range sets {
if state[entry.Name] == 0 {
if err := visit(entry.Name); err != nil {
return nil, err
}
}
}
return order, nil
}
// ResolvedEvictCosts returns a map of real model name -> evict cost,
// resolving var IDs. Models not listed default to 1.
func (m *MatrixConfig) ResolvedEvictCosts() map[string]int {
costs := make(map[string]int)
if m.EvictCosts == nil {
return costs
}
for key, cost := range m.EvictCosts {
// Resolve var ID if present
if realName, ok := m.Var[key]; ok {
costs[realName] = cost
} else {
costs[key] = cost
}
}
return costs
}
+376
View File
@@ -0,0 +1,376 @@
package config
import (
"fmt"
"sort"
"strings"
"unicode"
)
const maxDSLExpansions = 1000
// Token types for the DSL lexer
type tokenType int
const (
tokIdent tokenType = iota // model alias or name
tokAnd // &
tokOr // |
tokLParen // (
tokRParen // )
tokRef // +setName
tokEOF
)
type token struct {
typ tokenType
val string
}
// tokenize splits a DSL string into tokens.
func tokenize(input string) ([]token, error) {
var tokens []token
i := 0
runes := []rune(input)
for i < len(runes) {
ch := runes[i]
// skip whitespace
if unicode.IsSpace(ch) {
i++
continue
}
switch ch {
case '&':
tokens = append(tokens, token{tokAnd, "&"})
i++
case '|':
tokens = append(tokens, token{tokOr, "|"})
i++
case '(':
tokens = append(tokens, token{tokLParen, "("})
i++
case ')':
tokens = append(tokens, token{tokRParen, ")"})
i++
case '+':
// +ref: read the identifier that follows
i++
start := i
for i < len(runes) && isIdentChar(runes[i]) {
i++
}
if i == start {
return nil, fmt.Errorf("expected set name after '+' at position %d", start)
}
tokens = append(tokens, token{tokRef, string(runes[start:i])})
default:
if isIdentChar(ch) {
start := i
for i < len(runes) && isIdentChar(runes[i]) {
i++
}
tokens = append(tokens, token{tokIdent, string(runes[start:i])})
} else {
return nil, fmt.Errorf("unexpected character %q at position %d", ch, i)
}
}
}
tokens = append(tokens, token{tokEOF, ""})
return tokens, nil
}
func isIdentChar(ch rune) bool {
return unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' || ch == '-' || ch == '.'
}
// AST node types
type dslNode interface {
dslNode()
}
type andNode struct {
children []dslNode
}
type orNode struct {
children []dslNode
}
type leafNode struct {
name string
}
type refNode struct {
setName string
}
func (andNode) dslNode() {}
func (orNode) dslNode() {}
func (leafNode) dslNode() {}
func (refNode) dslNode() {}
// parser holds state for recursive-descent parsing.
type parser struct {
tokens []token
pos int
}
func (p *parser) peek() token {
if p.pos < len(p.tokens) {
return p.tokens[p.pos]
}
return token{tokEOF, ""}
}
func (p *parser) next() token {
t := p.peek()
if t.typ != tokEOF {
p.pos++
}
return t
}
func (p *parser) expect(typ tokenType) (token, error) {
t := p.next()
if t.typ != typ {
return t, fmt.Errorf("expected token type %d, got %q", typ, t.val)
}
return t, nil
}
// Grammar:
//
// expr = andExpr
// andExpr = orExpr ('&' orExpr)*
// orExpr = atom ('|' atom)*
// atom = ident | '+' ident | '(' expr ')'
//
// & binds tighter than |, so "a | b & c" means "a | (b & c)"
func parse(tokens []token) (dslNode, error) {
p := &parser{tokens: tokens}
node, err := p.parseExpr()
if err != nil {
return nil, err
}
if p.peek().typ != tokEOF {
return nil, fmt.Errorf("unexpected token %q after expression", p.peek().val)
}
return node, nil
}
func (p *parser) parseExpr() (dslNode, error) {
return p.parseOrExpr()
}
func (p *parser) parseOrExpr() (dslNode, error) {
left, err := p.parseAndExpr()
if err != nil {
return nil, err
}
if p.peek().typ == tokOr {
children := []dslNode{left}
for p.peek().typ == tokOr {
p.next() // consume |
right, err := p.parseAndExpr()
if err != nil {
return nil, err
}
children = append(children, right)
}
return orNode{children: children}, nil
}
return left, nil
}
func (p *parser) parseAndExpr() (dslNode, error) {
left, err := p.parseAtom()
if err != nil {
return nil, err
}
if p.peek().typ == tokAnd {
children := []dslNode{left}
for p.peek().typ == tokAnd {
p.next() // consume &
right, err := p.parseAtom()
if err != nil {
return nil, err
}
children = append(children, right)
}
return andNode{children: children}, nil
}
return left, nil
}
func (p *parser) parseAtom() (dslNode, error) {
t := p.peek()
switch t.typ {
case tokIdent:
p.next()
return leafNode{name: t.val}, nil
case tokRef:
p.next()
return refNode{setName: t.val}, nil
case tokLParen:
p.next() // consume (
node, err := p.parseExpr()
if err != nil {
return nil, err
}
if _, err := p.expect(tokRParen); err != nil {
return nil, fmt.Errorf("missing closing parenthesis")
}
return node, nil
default:
return nil, fmt.Errorf("unexpected token %q", t.val)
}
}
// expand walks the AST and produces all combinations.
// resolvedRefs contains previously expanded sets for +ref resolution.
func expand(node dslNode, resolvedRefs map[string][][]string) ([][]string, error) {
switch n := node.(type) {
case leafNode:
return [][]string{{n.name}}, nil
case refNode:
expanded, ok := resolvedRefs[n.setName]
if !ok {
return nil, fmt.Errorf("unknown set reference +%s", n.setName)
}
// Return a copy
result := make([][]string, len(expanded))
for i, combo := range expanded {
result[i] = make([]string, len(combo))
copy(result[i], combo)
}
return result, nil
case orNode:
// Union of all children's expansions
var result [][]string
for _, child := range n.children {
childResult, err := expand(child, resolvedRefs)
if err != nil {
return nil, err
}
result = append(result, childResult...)
if len(result) > maxDSLExpansions {
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
}
}
return result, nil
case andNode:
// Cartesian product across children
result := [][]string{{}} // start with one empty combo
for _, child := range n.children {
childResult, err := expand(child, resolvedRefs)
if err != nil {
return nil, err
}
result, err = cartesianProduct(result, childResult, maxDSLExpansions)
if err != nil {
return nil, err
}
}
return result, nil
default:
return nil, fmt.Errorf("unknown node type %T", node)
}
}
// cartesianProduct computes the cartesian product of two sets of combinations.
// It returns an error if the product would exceed cap.
func cartesianProduct(left, right [][]string, cap int) ([][]string, error) {
if int64(len(left))*int64(len(right)) > int64(cap) {
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", cap)
}
result := make([][]string, 0, len(left)*len(right))
for _, l := range left {
for _, r := range right {
combo := make([]string, 0, len(l)+len(r))
combo = append(combo, l...)
combo = append(combo, r...)
result = append(result, combo)
}
}
return result, nil
}
// ParseAndExpandDSL tokenizes, parses, and expands a DSL string.
// resolvedRefs contains previously expanded sets for +ref inlining.
func ParseAndExpandDSL(dsl string, resolvedRefs map[string][][]string) ([][]string, error) {
dsl = strings.TrimSpace(dsl)
if dsl == "" {
return nil, fmt.Errorf("empty DSL expression")
}
tokens, err := tokenize(dsl)
if err != nil {
return nil, fmt.Errorf("tokenize: %w", err)
}
tree, err := parse(tokens)
if err != nil {
return nil, fmt.Errorf("parse: %w", err)
}
result, err := expand(tree, resolvedRefs)
if err != nil {
return nil, err
}
// Deduplicate models within each combination and sort for consistency
for i, combo := range result {
result[i] = dedupAndSort(combo)
}
return result, nil
}
// dedupAndSort removes duplicate entries and sorts alphabetically.
func dedupAndSort(items []string) []string {
seen := make(map[string]bool, len(items))
var unique []string
for _, item := range items {
if !seen[item] {
seen[item] = true
unique = append(unique, item)
}
}
sort.Strings(unique)
return unique
}
// extractRefs scans a DSL string for +ref tokens without full parsing.
// Used for building the dependency graph for topological sorting.
func extractRefs(dsl string) ([]string, error) {
tokens, err := tokenize(dsl)
if err != nil {
return nil, err
}
var refs []string
seen := make(map[string]bool)
for _, t := range tokens {
if t.typ == tokRef && !seen[t.val] {
seen[t.val] = true
refs = append(refs, t.val)
}
}
return refs, nil
}
+300
View File
@@ -0,0 +1,300 @@
package config
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDSL_Tokenize(t *testing.T) {
tests := []struct {
name string
input string
expect []token
errMsg string
}{
{
name: "single identifier",
input: "abc",
expect: []token{
{tokIdent, "abc"},
{tokEOF, ""},
},
},
{
name: "identifier with hyphens and dots",
input: "model-name.v2",
expect: []token{
{tokIdent, "model-name.v2"},
{tokEOF, ""},
},
},
{
name: "and expression",
input: "a & b",
expect: []token{
{tokIdent, "a"},
{tokAnd, "&"},
{tokIdent, "b"},
{tokEOF, ""},
},
},
{
name: "or expression",
input: "a | b",
expect: []token{
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokEOF, ""},
},
},
{
name: "parentheses",
input: "(a | b) & c",
expect: []token{
{tokLParen, "("},
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokRParen, ")"},
{tokAnd, "&"},
{tokIdent, "c"},
{tokEOF, ""},
},
},
{
name: "ref token",
input: "+llms & v",
expect: []token{
{tokRef, "llms"},
{tokAnd, "&"},
{tokIdent, "v"},
{tokEOF, ""},
},
},
{
name: "no whitespace",
input: "(a|b)&c",
expect: []token{
{tokLParen, "("},
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokRParen, ")"},
{tokAnd, "&"},
{tokIdent, "c"},
{tokEOF, ""},
},
},
{
name: "empty ref",
input: "+",
errMsg: "expected set name after '+'",
},
{
name: "invalid character",
input: "a @ b",
errMsg: "unexpected character",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens, err := tokenize(tt.input)
if tt.errMsg != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expect, tokens)
}
})
}
}
func TestDSL_ParseAndExpand(t *testing.T) {
tests := []struct {
name string
dsl string
refs map[string][][]string
expect [][]string
errMsg string
}{
{
name: "single model",
dsl: "L",
expect: [][]string{{"L"}},
},
{
name: "two models with AND",
dsl: "a & b",
expect: [][]string{{"a", "b"}},
},
{
name: "two models with OR",
dsl: "a | b",
expect: [][]string{{"a"}, {"b"}},
},
{
name: "three models with OR",
dsl: "a | b | c",
expect: [][]string{{"a"}, {"b"}, {"c"}},
},
{
name: "cartesian product (a|b) & (c|d)",
dsl: "(a | b) & (c | d)",
expect: [][]string{
{"a", "c"},
{"a", "d"},
{"b", "c"},
{"b", "d"},
},
},
{
name: "three-way AND",
dsl: "a & b & c",
expect: [][]string{
{"a", "b", "c"},
},
},
{
name: "(g | q | m) & v",
dsl: "(g | q | m) & v",
expect: [][]string{
{"g", "v"},
{"q", "v"},
{"m", "v"},
},
},
{
name: "(g | q) & v & e",
dsl: "(g | q) & v & e",
expect: [][]string{
{"e", "g", "v"},
{"e", "q", "v"},
},
},
{
name: "precedence: a | b & c means a | (b & c)",
dsl: "a | b & c",
expect: [][]string{
{"a"},
{"b", "c"},
},
},
{
name: "+ref inlining",
dsl: "+llms & v",
refs: map[string][][]string{
"llms": {{"g"}, {"q"}, {"m"}},
},
expect: [][]string{
{"g", "v"},
{"q", "v"},
{"m", "v"},
},
},
{
name: "+ref chained",
dsl: "+with_tts & e",
refs: map[string][][]string{
"with_tts": {{"g", "v"}, {"q", "v"}, {"m", "v"}},
},
expect: [][]string{
{"e", "g", "v"},
{"e", "q", "v"},
{"e", "m", "v"},
},
},
{
name: "dedup within combination",
dsl: "a & a",
expect: [][]string{
{"a"},
},
},
{
name: "empty expression",
dsl: "",
errMsg: "empty DSL expression",
},
{
name: "unmatched open paren",
dsl: "(a | b",
errMsg: "missing closing parenthesis",
},
{
name: "unmatched close paren",
dsl: "a | b)",
errMsg: "unexpected token",
},
{
name: "unknown ref",
dsl: "+unknown",
errMsg: "unknown set reference +unknown",
},
{
name: "empty parens",
dsl: "()",
errMsg: "unexpected token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
refs := tt.refs
if refs == nil {
refs = map[string][][]string{}
}
result, err := ParseAndExpandDSL(tt.dsl, refs)
if tt.errMsg != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expect, result)
}
})
}
}
func TestDSL_ExpansionCap(t *testing.T) {
// Build an expression that would exceed 1000 combinations:
// (a1|a2|...|a32) & (b1|b2|...|b32) = 1024 combos
var aItems, bItems []string
for i := 0; i < 32; i++ {
aItems = append(aItems, fmt.Sprintf("a%d", i))
bItems = append(bItems, fmt.Sprintf("b%d", i))
}
dsl := fmt.Sprintf("(%s) & (%s)",
join(aItems, " | "),
join(bItems, " | "),
)
_, err := ParseAndExpandDSL(dsl, map[string][][]string{})
require.Error(t, err)
assert.Contains(t, err.Error(), "exceeded")
}
func TestDSL_ExtractRefs(t *testing.T) {
refs, err := extractRefs("+llms & v & +other")
require.NoError(t, err)
assert.Equal(t, []string{"llms", "other"}, refs)
refs, err = extractRefs("a & b")
require.NoError(t, err)
assert.Empty(t, refs)
}
func join(items []string, sep string) string {
result := ""
for i, item := range items {
if i > 0 {
result += sep
}
result += item
}
return result
}
+305
View File
@@ -0,0 +1,305 @@
package config
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func makeModels(names ...string) map[string]ModelConfig {
m := make(map[string]ModelConfig)
for _, name := range names {
m[name] = ModelConfig{Cmd: "echo " + name}
}
return m
}
func TestValidateMatrix_Basic(t *testing.T) {
models := makeModels("gemma", "qwen", "mistral", "voxtral", "llama70B")
matrix := MatrixConfig{
Var: map[string]string{
"g": "gemma",
"q": "qwen",
"m": "mistral",
"v": "voxtral",
"L": "llama70B",
},
EvictCosts: map[string]int{
"L": 30,
"v": 50,
},
Sets: OrderedSets{
{Name: "standard", DSL: "(g | q | m) & v"},
{Name: "full", DSL: "L"},
},
}
expanded, err := ValidateMatrix(matrix, models)
require.NoError(t, err)
// standard expands to [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
// full expands to [llama70B]
assert.Len(t, expanded, 4)
assert.Equal(t, "standard", expanded[0].SetName)
assert.Equal(t, []string{"gemma", "voxtral"}, expanded[0].Models)
assert.Equal(t, "standard", expanded[1].SetName)
assert.Equal(t, []string{"qwen", "voxtral"}, expanded[1].Models)
assert.Equal(t, "standard", expanded[2].SetName)
assert.Equal(t, []string{"mistral", "voxtral"}, expanded[2].Models)
assert.Equal(t, "full", expanded[3].SetName)
assert.Equal(t, []string{"llama70B"}, expanded[3].Models)
}
func TestValidateMatrix_WithRef(t *testing.T) {
models := makeModels("gemma", "qwen", "mistral", "voxtral", "reranker")
matrix := MatrixConfig{
Var: map[string]string{
"g": "gemma",
"q": "qwen",
"m": "mistral",
"v": "voxtral",
"e": "reranker",
},
Sets: OrderedSets{
{Name: "llms", DSL: "g | q | m"},
{Name: "with_tts", DSL: "+llms & v"},
{Name: "mega", DSL: "+with_tts & e"},
},
}
expanded, err := ValidateMatrix(matrix, models)
require.NoError(t, err)
// llms: [gemma], [qwen], [mistral]
// with_tts: [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
// mega: [gemma,reranker,voxtral], [qwen,reranker,voxtral], [mistral,reranker,voxtral]
assert.Len(t, expanded, 9)
// Check mega entries
megaEntries := filterBySetName(expanded, "mega")
assert.Len(t, megaEntries, 3)
assert.Equal(t, []string{"gemma", "reranker", "voxtral"}, megaEntries[0].Models)
}
func TestValidateMatrix_MapIDRequired(t *testing.T) {
// DSL cannot use real model names directly — must use var IDs
models := makeModels("gemma", "voxtral")
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "combo", DSL: "g & voxtral"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown var ID")
}
func TestValidateMatrix_InvalidAliasKey(t *testing.T) {
models := makeModels("gemma")
tests := []struct {
name string
alias string
errMsg string
}{
{"too long", "abcdefghi", "alphanumeric and 1-8 characters"},
{"has underscore", "a_b", "alphanumeric and 1-8 characters"},
{"has hyphen", "a-b", "alphanumeric and 1-8 characters"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matrix := MatrixConfig{
Var: map[string]string{tt.alias: "gemma"},
Sets: OrderedSets{{Name: "s", DSL: tt.alias}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
})
}
}
func TestValidateMatrix_AliasReferencesUnknownModel(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Var: map[string]string{"x": "nonexistent"},
Sets: OrderedSets{{Name: "s", DSL: "x"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown model")
}
func TestValidateMatrix_EvictCostInvalid(t *testing.T) {
models := makeModels("gemma")
t.Run("zero cost", func(t *testing.T) {
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"g": 0},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "positive integer")
})
t.Run("negative cost", func(t *testing.T) {
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"g": -1},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "positive integer")
})
t.Run("unknown var ID in evict_costs", func(t *testing.T) {
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"unknown": 5},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown var ID")
})
}
func TestValidateMatrix_CycleDetection(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "a", DSL: "+b"},
{Name: "b", DSL: "+a"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "circular reference")
}
func TestValidateMatrix_UndefinedRefTarget(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "a", DSL: "+nonexistent"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "references undefined set")
}
func TestValidateMatrix_NoSets(t *testing.T) {
_, err := ValidateMatrix(MatrixConfig{}, makeModels("gemma"))
require.Error(t, err)
assert.Contains(t, err.Error(), "at least one set")
}
func TestValidateMatrix_UnknownMapIDInDSL(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "s", DSL: "g & nonexistent"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown var ID")
}
func TestValidateMatrix_ResolvedEvictCosts(t *testing.T) {
mc := &MatrixConfig{
Var: map[string]string{
"g": "gemma",
"L": "llama70B",
},
EvictCosts: map[string]int{
"L": 30,
"g": 5,
},
}
costs := mc.ResolvedEvictCosts()
assert.Equal(t, 30, costs["llama70B"])
assert.Equal(t, 5, costs["gemma"])
}
func TestValidateMatrix_ConfigXOR(t *testing.T) {
// groups and matrix both defined
yaml := `
models:
model1:
cmd: echo model1
proxy: http://localhost:8080
groups:
group1:
members:
- model1
matrix:
sets:
s: "model1"
`
_, err := LoadConfigFromReader(strings.NewReader(yaml))
require.Error(t, err)
assert.Contains(t, err.Error(), "cannot use both")
}
func TestValidateMatrix_ConfigMatrixOnly(t *testing.T) {
yaml := `
models:
gemma:
cmd: echo gemma
proxy: http://localhost:8080
qwen:
cmd: echo qwen
proxy: http://localhost:8081
matrix:
vars:
g: gemma
q: qwen
sets:
combo: "g | q"
`
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
assert.NotNil(t, cfg.Matrix)
assert.Len(t, cfg.ExpandedSets, 2)
// Groups should be empty when matrix is used
assert.Empty(t, cfg.Groups)
}
func filterBySetName(sets []ExpandedSet, name string) []ExpandedSet {
var result []ExpandedSet
for _, s := range sets {
if s.SetName == name {
result = append(result, s)
}
}
return result
}
+11 -6
View File
@@ -10,12 +10,14 @@ const (
) )
// TimeoutsConfig holds timeout settings for proxy connections // TimeoutsConfig holds timeout settings for proxy connections
// 0 = no timeout
type TimeoutsConfig struct { type TimeoutsConfig struct {
Connect int `yaml:"connect"` // seconds, 0 = no timeout (not recommended) Connect int `yaml:"connect"`
ResponseHeader int `yaml:"responseHeader"` // seconds, 0 = no timeout (not recommended) KeepAlive int `yaml:"keepalive"`
TLSHandshake int `yaml:"tlsHandshake"` // seconds, 0 = no timeout (not recommended) ResponseHeader int `yaml:"responseHeader"`
ExpectContinue int `yaml:"expectContinue"` // seconds, 0 = no timeout (not recommended) TLSHandshake int `yaml:"tlsHandshake"`
IdleConn int `yaml:"idleConn"` // seconds, 0 = no timeout (not recommended) ExpectContinue int `yaml:"expectContinue"`
IdleConn int `yaml:"idleConn"`
} }
type ModelConfig struct { type ModelConfig struct {
@@ -69,9 +71,12 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
ConcurrencyLimit: 0, ConcurrencyLimit: 0,
Name: "", Name: "",
Description: "", Description: "",
// matches http.DefaultTransport
Timeouts: TimeoutsConfig{ Timeouts: TimeoutsConfig{
Connect: 30, Connect: 30,
ResponseHeader: 60, KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10, TLSHandshake: 10,
ExpectContinue: 1, ExpectContinue: 1,
IdleConn: 90, IdleConn: 90,
+4
View File
@@ -24,8 +24,12 @@ func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
ApiKey: "", ApiKey: "",
Models: []string{}, Models: []string{},
Filters: Filters{}, Filters: Filters{},
// mostly matches http.DefaultTransport but with a 60s ResponseHeader timeout
// to match the pre PR #619 functionality
Timeouts: TimeoutsConfig{ Timeouts: TimeoutsConfig{
Connect: 30, Connect: 30,
KeepAlive: 30,
ResponseHeader: 60, ResponseHeader: 60,
TLSHandshake: 10, TLSHandshake: 10,
ExpectContinue: 1, ExpectContinue: 1,
+85
View File
@@ -0,0 +1,85 @@
// Package configwatcher provides a simple cross-platform file watcher based
// on os.Stat polling. It works correctly inside Docker containers where the
// config file is bind-mounted as an individual file, and for k8s ConfigMap
// projections (which present the file as a symlink to an atomically swapped
// target) — both cases where inotify-based watchers are unreliable.
package configwatcher
import (
"context"
"errors"
"io/fs"
"log"
"os"
"time"
)
const DefaultInterval = 2 * time.Second
type Watcher struct {
Path string
Interval time.Duration
OnChange func()
}
type snapshot struct {
exists bool
modTime time.Time
size int64
}
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
// OnChange whenever the file's modification time or size changes, or when
// the file reappears after being missing. The baseline poll establishes
// initial state and does not fire OnChange.
func (w *Watcher) Run(ctx context.Context) {
interval := w.Interval
if interval <= 0 {
interval = DefaultInterval
}
prev := stat(w.Path)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
cur := stat(w.Path)
if changed(prev, cur) && w.OnChange != nil {
w.OnChange()
}
prev = cur
}
}
}
func stat(path string) snapshot {
fi, err := os.Stat(path)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Printf("configwatcher: stat %s: %v", path, err)
}
return snapshot{}
}
return snapshot{
exists: true,
modTime: fi.ModTime(),
size: fi.Size(),
}
}
func changed(prev, cur snapshot) bool {
// Present → missing: stay quiet (likely a transient rename-style write).
// Missing → present: fire so we reload as soon as the file comes back.
if !cur.exists {
return false
}
if !prev.exists {
return true
}
return !prev.modTime.Equal(cur.modTime) || prev.size != cur.size
}
+191
View File
@@ -0,0 +1,191 @@
package configwatcher
import (
"context"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testInterval = 25 * time.Millisecond
// startWatcher launches w.Run in a goroutine and returns a function that
// cancels the context and waits for Run to return.
func startWatcher(t *testing.T, w *Watcher) func() {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
w.Run(ctx)
close(done)
}()
return func() {
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("watcher did not stop within 2s of cancel")
}
}
}
// waitForCount blocks until counter reaches want or timeout elapses.
func waitForCount(t *testing.T, counter *int64, want int64, timeout time.Duration) bool {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if atomic.LoadInt64(counter) >= want {
return true
}
time.Sleep(5 * time.Millisecond)
}
return false
}
func TestWatcher_NoFireOnBaseline(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
var n int64
stop := startWatcher(t, &Watcher{
Path: path,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 5)
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
}
func TestWatcher_DetectsModTimeChange(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
// Force a known baseline mtime.
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
require.NoError(t, os.Chtimes(path, base, base))
var n int64
stop := startWatcher(t, &Watcher{
Path: path,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
// Let the baseline settle.
time.Sleep(testInterval * 2)
// Bump mtime well above the baseline so low-resolution filesystems still notice.
require.NoError(t, os.Chtimes(path, base.Add(10*time.Second), base.Add(10*time.Second)))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
}
func TestWatcher_DetectsSizeChangeWithSameModTime(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
fi, err := os.Stat(path)
require.NoError(t, err)
originalMtime := fi.ModTime()
var n int64
stop := startWatcher(t, &Watcher{
Path: path,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
require.NoError(t, os.WriteFile(path, []byte("aaaaa"), 0o644))
// Reset mtime back to the original so size is the only signal.
require.NoError(t, os.Chtimes(path, originalMtime, originalMtime))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire on size change")
}
func TestWatcher_SymlinkTargetSwap(t *testing.T) {
dir := t.TempDir()
targetA := filepath.Join(dir, "targetA")
targetB := filepath.Join(dir, "targetB")
link := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(targetA, []byte("AAAA"), 0o644))
require.NoError(t, os.WriteFile(targetB, []byte("BBBBBBBB"), 0o644))
if err := os.Symlink(targetA, link); err != nil {
if runtime.GOOS == "windows" {
t.Skipf("symlink creation requires privilege on Windows: %v", err)
}
t.Fatalf("os.Symlink: %v", err)
}
var n int64
stop := startWatcher(t, &Watcher{
Path: link,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
// Atomic symlink swap (k8s ConfigMap pattern): create new symlink at a
// temp name, then rename over the existing one.
tmpLink := filepath.Join(dir, "config.yaml.tmp")
require.NoError(t, os.Symlink(targetB, tmpLink))
require.NoError(t, os.Rename(tmpLink, link))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after symlink target swap")
}
func TestWatcher_FileMissingThenReturns(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
var n int64
stop := startWatcher(t, &Watcher{
Path: path,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
require.NoError(t, os.Remove(path))
time.Sleep(testInterval * 3)
require.Equal(t, int64(0), atomic.LoadInt64(&n), "removal alone must not fire")
require.NoError(t, os.WriteFile(path, []byte("b"), 0o644))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when file returns")
}
func TestWatcher_ContextCancelStopsRun(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
w := &Watcher{Path: path, Interval: testInterval}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() { w.Run(ctx); close(done) }()
time.Sleep(testInterval * 2)
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Run did not return within 2s of cancel")
}
}
+1 -1
View File
@@ -6,7 +6,7 @@ const ProcessStateChangeEventID = 0x01
const ChatCompletionStatsEventID = 0x02 const ChatCompletionStatsEventID = 0x02
const ConfigFileChangedEventID = 0x03 const ConfigFileChangedEventID = 0x03
const LogDataEventID = 0x04 const LogDataEventID = 0x04
const TokenMetricsEventID = 0x05 const ActivityLogEventID = 0x05
const ModelPreloadedEventID = 0x06 const ModelPreloadedEventID = 0x06
const InFlightRequestsEventID = 0x07 const InFlightRequestsEventID = 0x07
+202
View File
@@ -1,15 +1,22 @@
package proxy package proxy
import ( import (
"encoding/json"
"fmt" "fmt"
"io"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"sync" "sync"
"testing" "testing"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy/config" "github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -66,6 +73,16 @@ func getTestPort() int {
return port return port
} }
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
t.Helper()
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
require.NoError(t, err)
return cfg
}
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig { func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort()) return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
} }
@@ -88,3 +105,188 @@ proxy: "http://127.0.0.1:%d"
return cfg return cfg
} }
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
// model IDs to their respond strings; if a model ID is not in the map, the model
// ID itself is used.
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
for _, pg := range pm.processGroups {
for modelID, process := range pg.processes {
respond := modelID
if r, ok := modelResponses[modelID]; ok {
respond = r
}
process.testHandler = newTestHandler(respond)
}
}
}
// newTestHandler returns an http.Handler that mimics simple-responder's API.
// It supports the endpoints that routing tests depend on, without launching
// any subprocess or binding any port.
func newTestHandler(respond string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
isStreaming := r.URL.Query().Get("stream") == "true"
if wait := r.URL.Query().Get("wait"); wait != "" {
if d, err := time.ParseDuration(wait); err == nil {
time.Sleep(d)
}
}
if isStreaming {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
flusher := w.(http.Flusher)
for i := 0; i < 10; i++ {
data, _ := json.Marshal(map[string]any{
"created": time.Now().Unix(),
"choices": []map[string]any{
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
},
})
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
flusher.Flush()
}
finalData, _ := json.Marshal(map[string]any{
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
"timings": map[string]any{
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
"predicted_ms": 17, "predicted_per_second": 10,
},
})
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
flusher.Flush()
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
flusher.Flush()
} else {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"h_content_length": r.Header.Get("Content-Length"),
"request_body": string(bodyBytes),
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
"timings": map[string]any{
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
"predicted_ms": 17, "predicted_per_second": 10,
},
})
}
})
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
if modelName != respond {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
return
}
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
})
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
})
})
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
})
})
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
model := r.FormValue("model")
if model == "" {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
return
}
file, _, err := r.FormFile("file")
if err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
fileBytes, _ := io.ReadAll(file)
file.Close()
json.NewEncoder(w).Encode(map[string]any{
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
"model": model,
"h_content_type": r.Header.Get("Content-Type"),
"h_content_length": r.Header.Get("Content-Length"),
})
})
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
model := r.URL.Query().Get("model")
json.NewEncoder(w).Encode(map[string]any{
"voices": []string{"voice1"}, "model": model,
})
})
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprint(w, respond)
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
})
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
json.NewEncoder(w).Encode(map[string]any{
"model": modelName, "images": []string{},
})
})
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
json.NewEncoder(w).Encode(map[string]any{
"model": modelName, "images": []string{},
})
})
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"loras": []string{},
})
})
return mux
}
+329
View File
@@ -0,0 +1,329 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sort"
"sync"
"github.com/mostlygeek/llama-swap/proxy/config"
)
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
// It is safe for concurrent reads after construction.
type MatrixSolver struct {
expandedSets []config.ExpandedSet // all valid model combinations
evictCosts map[string]int // real model name -> eviction cost (default 1)
modelToSets map[string][]int // model name -> indices into expandedSets
}
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
modelToSets := make(map[string][]int)
for i, es := range expandedSets {
for _, model := range es.Models {
modelToSets[model] = append(modelToSets[model], i)
}
}
return &MatrixSolver{
expandedSets: expandedSets,
evictCosts: evictCosts,
modelToSets: modelToSets,
}
}
// SolveResult describes what the solver decided.
type SolveResult struct {
Evict []string // running models that must be stopped
TargetSet []string // the chosen set of models (for informational purposes)
SetName string // name of the chosen set
DSL string // original DSL expression for the chosen set
TotalCost int // total eviction cost
}
// Solve determines which models to evict when a model is requested.
//
// Algorithm:
// 1. If requestedModel is already running, no eviction needed.
// 2. Find all sets containing requestedModel.
// 3. If no sets found, the model runs alone; evict all running models.
// 4. For each candidate set, compute cost = sum of evict_costs for running
// models NOT in that set.
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
// 6. Return models to evict and the chosen set.
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
// If already running, nothing to do (but fill in set info for logging)
if slices.Contains(runningModels, requestedModel) {
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
return SolveResult{
TargetSet: runningModels,
SetName: setName,
DSL: dsl,
}, nil
}
candidateIndices := s.modelToSets[requestedModel]
// Model not in any set: runs alone, evict everything
if len(candidateIndices) == 0 {
evict := make([]string, len(runningModels))
copy(evict, runningModels)
return SolveResult{
Evict: evict,
TargetSet: []string{requestedModel},
}, nil
}
// Find the cheapest candidate set
bestCost := -1
bestIdx := -1
for _, idx := range candidateIndices {
setModels := s.expandedSets[idx].Models
cost := 0
for _, running := range runningModels {
if !slices.Contains(setModels, running) {
cost += s.evictCost(running)
}
}
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
bestCost = cost
bestIdx = idx
}
}
// Determine which running models to evict
chosen := s.expandedSets[bestIdx]
var evict []string
for _, running := range runningModels {
if !slices.Contains(chosen.Models, running) {
evict = append(evict, running)
}
}
return SolveResult{
Evict: evict,
TargetSet: chosen.Models,
SetName: chosen.SetName,
DSL: chosen.DSL,
TotalCost: bestCost,
}, nil
}
// findMatchingSet finds the expanded set that contains all running models.
// Returns the set name and DSL, or empty strings if no match.
func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
for _, idx := range s.modelToSets[requestedModel] {
set := s.expandedSets[idx]
allInSet := true
for _, m := range runningModels {
if !slices.Contains(set.Models, m) {
allInSet = false
break
}
}
if allInSet {
return set.SetName, set.DSL
}
}
return "", ""
}
func (s *MatrixSolver) evictCost(model string) int {
if cost, ok := s.evictCosts[model]; ok {
return cost
}
return 1
}
// Matrix manages processes using solver-based swap logic.
type Matrix struct {
sync.Mutex
solver *MatrixSolver
processes map[string]*Process // all processes keyed by real model name
config config.Config
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// inflight tracks ProxyRequest calls that have released m.Lock but may
// not yet have incremented Process.inFlightRequests. A concurrent
// request that needs to evict models waits for inflight to drain under
// m.Lock before stopping anything. Without this, a request that
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
// races with Stop()'s Wait() and can be killed mid-request.
inflight sync.WaitGroup
// testDelayFastPath is a test-only hook invoked in the no-eviction path
// after m.Lock is released but before the request is dispatched to
// Process.ProxyRequest. Tests use it to park a request at the exact
// race window to deterministically reproduce the race.
testDelayFastPath func()
}
// NewMatrix creates a Matrix from config. It creates a Process for every
// model defined in the config (any model can run alone even if not in a set).
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *LogMonitor) *Matrix {
processes := make(map[string]*Process)
for modelID, modelConfig := range cfg.Models {
processLogger := NewLogMonitorWriter(upstreamLogger)
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
processes[modelID] = process
}
evictCosts := cfg.Matrix.ResolvedEvictCosts()
return &Matrix{
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
processes: processes,
config: cfg,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
}
}
// ProxyRequest handles the swap logic and proxies the request to the model.
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
process, ok := m.processes[modelID]
if !ok {
return fmt.Errorf("model %s not found in matrix", modelID)
}
m.Lock()
running := m.runningModels()
result, err := m.solver.Solve(modelID, running)
if err != nil {
m.Unlock()
return fmt.Errorf("matrix solver error: %w", err)
}
// Log solver decision
if len(result.Evict) > 0 {
m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
} else if len(running) == 0 {
m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
} else {
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
}
// Evict models that need to be stopped
if len(result.Evict) > 0 {
// Wait for any in-flight ProxyRequest calls to register on their
// Process before stopping anything. Without this, a request that
// released m.Lock but has not yet incremented
// Process.inFlightRequests races with Stop() and can be killed
// mid-request.
m.inflight.Wait()
var wg sync.WaitGroup
for _, evictModel := range result.Evict {
if p, exists := m.processes[evictModel]; exists {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
p.Stop()
}(p)
}
}
wg.Wait()
}
// Register this request in inflight before releasing m.Lock so a
// concurrent eviction will wait for it to complete.
m.inflight.Add(1)
defer m.inflight.Done()
isFastPath := len(result.Evict) == 0
m.Unlock()
if isFastPath && m.testDelayFastPath != nil {
m.testDelayFastPath()
}
// Proxy the request (Process handles on-demand start)
process.ProxyRequest(w, r)
return nil
}
// StopProcesses stops all running processes.
func (m *Matrix) StopProcesses(strategy StopStrategy) {
m.Lock()
defer m.Unlock()
var wg sync.WaitGroup
for _, process := range m.processes {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
switch strategy {
case StopImmediately:
p.StopImmediately()
default:
p.Stop()
}
}(process)
}
wg.Wait()
}
// StopProcess stops a single process by model ID.
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
process, ok := m.processes[modelID]
if !ok {
return fmt.Errorf("process not found for %s", modelID)
}
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
return nil
}
// Shutdown shuts down all processes.
func (m *Matrix) Shutdown() {
var wg sync.WaitGroup
for _, process := range m.processes {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
p.Shutdown()
}(process)
}
wg.Wait()
}
// RunningModels returns model names currently in an active (non-stopped) state.
func (m *Matrix) RunningModels() []string {
m.Lock()
defer m.Unlock()
return m.runningModels()
}
// runningModels returns running model names (caller must hold lock).
func (m *Matrix) runningModels() []string {
var running []string
for id, process := range m.processes {
if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
running = append(running, id)
}
}
sort.Strings(running)
return running
}
// GetProcess returns the Process for a model.
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
p, ok := m.processes[modelID]
return p, ok
}
// HasModel returns true if the model is managed by this matrix.
func (m *Matrix) HasModel(modelID string) bool {
_, ok := m.processes[modelID]
return ok
}
+349
View File
@@ -0,0 +1,349 @@
package proxy
import (
"net/http"
"net/http/httptest"
"runtime"
"testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Helper to build expanded sets for solver tests
func makeExpandedSets(sets ...struct {
name string
models []string
}) []config.ExpandedSet {
var result []config.ExpandedSet
for _, s := range sets {
result = append(result, config.ExpandedSet{
SetName: s.name,
Models: s.models,
})
}
return result
}
func es(name string, models ...string) struct {
name string
models []string
} {
return struct {
name string
models []string
}{name, models}
}
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("a", []string{"a"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a"}, result.TargetSet)
assert.Equal(t, "s1", result.SetName)
}
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
// Model "c" not in any set
result, err := solver.Solve("c", []string{"a", "b"})
require.NoError(t, err)
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
assert.Equal(t, []string{"c"}, result.TargetSet)
}
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("c", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"c"}, result.TargetSet)
}
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
// Set: [a, b]. Request a when b and c are running.
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("a", []string{"b", "c"})
require.NoError(t, err)
// c is not in the set, so it gets evicted. b is in the set, so it stays.
assert.Equal(t, []string{"c"}, result.Evict)
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
}
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
// Two sets containing model "a":
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "a", "v"),
es("s2", "a", "L"),
),
map[string]int{"v": 50, "L": 30},
)
// v is running. Switching to a:
// s1 cost: v is in s1, so 0
// s2 cost: v is NOT in s2, so 50
// => pick s1
result, err := solver.Solve("a", []string{"v"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
// L is running. Switching to a:
// s1 cost: L is NOT in s1, so 30
// s2 cost: L is in s2, so 0
// => pick s2
result, err = solver.Solve("a", []string{"L"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
}
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
// Two sets with identical cost. Definition order should win.
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "a", "x"),
es("s2", "a", "y"),
),
nil,
)
// Nothing running, both sets cost 0. s1 is first.
result, err := solver.Solve("a", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
}
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
// Model "v" costs 50 to evict, "m" costs 1 (default).
// Sets: [g,v], [g,m]
// Running: v, m. Request g.
// s1=[g,v]: evict m (cost 1), keep v
// s2=[g,m]: evict v (cost 50), keep m
// => pick s1
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "g", "v"),
es("s2", "g", "m"),
),
map[string]int{"v": 50},
)
result, err := solver.Solve("g", []string{"v", "m"})
require.NoError(t, err)
assert.Equal(t, []string{"m"}, result.Evict)
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
}
func TestMatrixSolver_NothingRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "g", "v"),
es("s2", "q", "v"),
),
nil,
)
result, err := solver.Solve("g", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
}
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
// cannot stop a process while an in-flight ProxyRequest for that process is
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
// matrix-level inflight tracking, the eviction's Stop() races with the
// pending request and kills it mid-start.
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
cfg := config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
ExpandedSets: []config.ExpandedSet{
{SetName: "s1", Models: []string{"model1"}},
{SetName: "s2", Models: []string{"model2"}},
},
Matrix: &config.MatrixConfig{},
}
m := NewMatrix(cfg, testLogger, testLogger)
defer m.StopProcesses(StopImmediately)
// Bypass real subprocesses so the test is fast and deterministic.
m.processes["model1"].testHandler = newTestHandler("model1")
m.processes["model2"].testHandler = newTestHandler("model2")
// Prime: run a request through model1 so it reaches StateReady and
// subsequent requests take the no-eviction path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
// Install fast-path hook that signals arrival and waits for release.
// This parks R2 at the race window — after m.Lock is released but
// before Process.inFlightRequests.Add(1).
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
m.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
// R2: no-eviction request for model1. Will pause at the hook.
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, m.ProxyRequest("model1", w2, req))
}()
// Deterministically wait for R2 to reach the race window.
<-r2Reached
// R3: request for model2 which requires evicting model1. Must wait for
// R2 to finish before touching model1.
r3Done := make(chan struct{})
w3 := httptest.NewRecorder()
go func() {
defer close(r3Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, m.ProxyRequest("model2", w3, req))
}()
// Spin until R3 has acquired m.Lock and entered the eviction path. In
// the fixed code, R3 then blocks on m.inflight.Wait() while still
// holding the lock, so TryLock keeps failing.
for m.TryLock() {
m.Unlock()
runtime.Gosched()
}
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
// state. In the fixed code R3 is blocked and nothing changes; in the
// buggy code R3 will Stop() model1 and start model2 within microseconds.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if m.processes["model1"].CurrentState() != StateReady ||
m.processes["model2"].CurrentState() != StateStopped {
break
}
done := false
select {
case <-r3Done:
done = true
default:
}
if done {
break
}
runtime.Gosched()
}
// Invariant: R3 must be blocked while R2 is still in flight.
select {
case <-r3Done:
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
default:
}
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
"model1 must stay Ready while an in-flight request is pending")
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
"model2 must not be started until R2 finishes and model1 is evicted")
// Release R2 and let both requests finish.
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
assert.Equal(t, http.StatusOK, w3.Code)
assert.Contains(t, w3.Body.String(), "model2")
}
func TestMatrixSolver_FullScenario(t *testing.T) {
// Simulates the example config:
// standard: [g,v], [q,v], [m,v]
// with_rerank: [g,v,e], [q,v,e]
// creative: [g,sd], [q,sd]
// full: [L]
solver := NewMatrixSolver(
makeExpandedSets(
es("standard", "g", "v"),
es("standard", "q", "v"),
es("standard", "m", "v"),
es("with_rerank", "e", "g", "v"),
es("with_rerank", "e", "q", "v"),
es("creative", "g", "sd"),
es("creative", "q", "sd"),
es("full", "L"),
),
map[string]int{"v": 50, "L": 30, "whisper": 10},
)
// Running: g, v. Request q.
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
// => tie, pick first by definition order = standard[q,v]
result, err := solver.Solve("q", []string{"g", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"g"}, result.Evict)
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
// Running: g, v. Request L.
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
// Only one set contains L, so pick it.
result, err = solver.Solve("L", []string{"g", "v"})
require.NoError(t, err)
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
assert.Equal(t, []string{"L"}, result.TargetSet)
// Running: g, v. Request sd.
// creative[g,sd]: evict v (cost 50). Total: 50.
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
// => pick creative[g,sd]
result, err = solver.Solve("sd", []string{"g", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"v"}, result.Evict)
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
// Running: q, v, e. Request g.
// standard[g,v]: evict q (1) + e (1). Total: 2.
// with_rerank[g,v,e]: evict q (1). Total: 1.
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
// => pick with_rerank[g,v,e]
result, err = solver.Solve("g", []string{"e", "q", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"q"}, result.Evict)
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
}
+239 -121
View File
@@ -12,23 +12,85 @@ import (
"sync" "sync"
"time" "time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/cache"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
// TokenMetrics represents parsed token statistics from llama-server logs // zstdEncOptions are the shared zstd encoder options for maximum compression.
var zstdEncOptions = []zstd.EOption{
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
}
// zstdDecOptions are the shared zstd decoder options.
var zstdDecOptions = []zstd.DOption{}
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
var zstdEncPool = &sync.Pool{
New: func() interface{} {
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
return enc
},
}
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
var zstdDecPool = &sync.Pool{
New: func() interface{} {
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
return dec
},
}
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
// Returns compressed bytes and the original CBOR byte count for logging.
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
cborBytes, err := cbor.Marshal(c)
if err != nil {
return nil, 0, fmt.Errorf("marshal capture: %w", err)
}
zenc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(zenc)
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
}
// decompressCapture decompresses zstd-compressed CBOR and unmarshals it into a ReqRespCapture.
func decompressCapture(data []byte) (*ReqRespCapture, error) {
dec := zstdDecPool.Get().(*zstd.Decoder)
defer zstdDecPool.Put(dec)
cborBytes, err := dec.DecodeAll(data, nil)
if err != nil {
return nil, fmt.Errorf("decompress capture: %w", err)
}
var capture ReqRespCapture
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
return nil, fmt.Errorf("unmarshal capture: %w", err)
}
return &capture, nil
}
// TokenMetrics holds token usage and performance metrics
type TokenMetrics struct { type TokenMetrics struct {
ID int `json:"id"` CachedTokens int `json:"cache_tokens"`
Timestamp time.Time `json:"timestamp"` InputTokens int `json:"input_tokens"`
Model string `json:"model"` OutputTokens int `json:"output_tokens"`
CachedTokens int `json:"cache_tokens"` PromptPerSecond float64 `json:"prompt_per_second"`
InputTokens int `json:"input_tokens"` TokensPerSecond float64 `json:"tokens_per_second"`
OutputTokens int `json:"output_tokens"` }
PromptPerSecond float64 `json:"prompt_per_second"`
TokensPerSecond float64 `json:"tokens_per_second"` // ActivityLogEntry represents parsed token statistics from llama-server logs
DurationMs int `json:"duration_ms"` type ActivityLogEntry struct {
HasCapture bool `json:"has_capture"` ID int `json:"id"`
Timestamp time.Time `json:"timestamp"`
Model string `json:"model"`
ReqPath string `json:"req_path"`
RespContentType string `json:"resp_content_type"`
RespStatusCode int `json:"resp_status_code"`
Tokens TokenMetrics `json:"tokens"`
DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"`
} }
type ReqRespCapture struct { type ReqRespCapture struct {
@@ -40,60 +102,45 @@ type ReqRespCapture struct {
RespBody []byte `json:"resp_body"` RespBody []byte `json:"resp_body"`
} }
// Size returns the approximate memory usage of this capture in bytes // ActivityLogEvent represents a token metrics event
func (c *ReqRespCapture) Size() int { type ActivityLogEvent struct {
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody) Metrics ActivityLogEntry
for k, v := range c.ReqHeaders {
size += len(k) + len(v)
}
for k, v := range c.RespHeaders {
size += len(k) + len(v)
}
return size
} }
// TokenMetricsEvent represents a token metrics event func (e ActivityLogEvent) Type() uint32 {
type TokenMetricsEvent struct { return ActivityLogEventID // defined in events.go
Metrics TokenMetrics
}
func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
} }
// metricsMonitor parses llama-server output for token statistics // metricsMonitor parses llama-server output for token statistics
type metricsMonitor struct { type metricsMonitor struct {
mu sync.RWMutex mu sync.RWMutex
metrics []TokenMetrics metrics []ActivityLogEntry
maxMetrics int maxMetrics int
nextID int nextID int
logger *LogMonitor logger *LogMonitor
// capture fields // capture fields
enableCaptures bool enableCaptures bool
captures map[int]ReqRespCapture // map for O(1) lookup by ID captureCache *cache.Cache // zstd-compressed CBOR of ReqRespCapture
captureOrder []int // track insertion order for FIFO eviction
captureSize int // current total size in bytes
maxCaptureSize int // max bytes for captures
} }
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the // newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
// capture buffer size in megabytes; 0 disables captures. // capture buffer size in megabytes; 0 disables captures.
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor { func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
return &metricsMonitor{ mm := &metricsMonitor{
logger: logger, logger: logger,
maxMetrics: maxMetrics, maxMetrics: maxMetrics,
enableCaptures: captureBufferMB > 0, enableCaptures: captureBufferMB > 0,
captures: make(map[int]ReqRespCapture),
captureOrder: make([]int, 0),
captureSize: 0,
maxCaptureSize: captureBufferMB * 1024 * 1024,
} }
if captureBufferMB > 0 {
mm.captureCache = cache.New(captureBufferMB * 1024 * 1024)
}
return mm
} }
// addMetrics adds a new metric to the collection and publishes an event. // queueMetrics adds a new metric to the collection without emitting an event.
// Returns the assigned metric ID. // Returns the assigned metric ID. Call emitMetric after capture setup.
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int { func (mp *metricsMonitor) queueMetrics(metric ActivityLogEntry) int {
mp.mu.Lock() mp.mu.Lock()
defer mp.mu.Unlock() defer mp.mu.Unlock()
@@ -103,58 +150,75 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
if len(mp.metrics) > mp.maxMetrics { if len(mp.metrics) > mp.maxMetrics {
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:] mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
} }
event.Emit(TokenMetricsEvent{Metrics: metric})
return metric.ID return metric.ID
} }
// addCapture adds a new capture to the buffer with size-based eviction. // emitMetric publishes an ActivityLogEvent for the given metric.
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize. func (mp *metricsMonitor) emitMetric(metric ActivityLogEntry) {
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) { event.Emit(ActivityLogEvent{Metrics: metric})
if !mp.enableCaptures {
return
}
mp.mu.Lock()
defer mp.mu.Unlock()
captureSize := capture.Size()
if captureSize > mp.maxCaptureSize {
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
return
}
// Evict oldest (FIFO) until room available
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
oldestID := mp.captureOrder[0]
mp.captureOrder = mp.captureOrder[1:]
if evicted, exists := mp.captures[oldestID]; exists {
mp.captureSize -= evicted.Size()
delete(mp.captures, oldestID)
}
}
mp.captures[capture.ID] = capture
mp.captureOrder = append(mp.captureOrder, capture.ID)
mp.captureSize += captureSize
} }
// getCaptureByID returns a capture by its ID, or nil if not found. // addCapture compresses and stores a capture in the cache.
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture { // Returns true if the capture was stored, false otherwise.
mp.mu.RLock() func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
defer mp.mu.RUnlock() if !mp.enableCaptures {
return false
if capture, exists := mp.captures[id]; exists {
return &capture
} }
return nil
compressed, uncompressedBytes, err := compressCapture(&capture)
if err != nil {
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
return false
}
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
return false
}
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
return true
}
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
if mp.captureCache == nil {
return nil, false
}
data, err := mp.captureCache.Get(id)
if err != nil {
return nil, false
}
return data, true
}
// getCaptureByID decompresses and unmarshals a capture by ID.
// Returns nil if the capture is not found or decompression fails.
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
if mp.captureCache == nil {
return nil
}
data, exists := mp.getCompressedBytes(id)
if !exists {
return nil
}
capture, err := decompressCapture(data)
if err != nil {
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
return nil
}
return capture
} }
// getMetrics returns a copy of the current metrics // getMetrics returns a copy of the current metrics
func (mp *metricsMonitor) getMetrics() []TokenMetrics { func (mp *metricsMonitor) getMetrics() []ActivityLogEntry {
mp.mu.RLock() mp.mu.RLock()
defer mp.mu.RUnlock() defer mp.mu.RUnlock()
result := make([]TokenMetrics, len(mp.metrics)) result := make([]ActivityLogEntry, len(mp.metrics))
copy(result, mp.metrics) copy(result, mp.metrics)
return result return result
} }
@@ -163,22 +227,52 @@ func (mp *metricsMonitor) getMetrics() []TokenMetrics {
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) { func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
mp.mu.RLock() mp.mu.RLock()
defer mp.mu.RUnlock() defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
if mp.captureCache == nil {
return json.Marshal(mp.metrics)
}
// Make a copy with up-to-date has_capture from cache
result := make([]ActivityLogEntry, len(mp.metrics))
for i, m := range mp.metrics {
m.HasCapture = mp.captureCache.Has(m.ID)
result[i] = m
}
return json.Marshal(result)
} }
// wrapHandler wraps the proxy handler to extract token metrics // Capture field flags for controlling what is saved in ReqRespCapture.
type captureFields uint
const (
captureNone captureFields = 1 << iota
captureReqHeaders
captureReqBody
captureRespHeaders
captureRespBody
)
const (
captureReqAll = captureReqHeaders | captureReqBody
captureRespAll = captureRespHeaders | captureRespBody
captureAll = captureReqAll | captureRespAll
)
// wrapHandler wraps the proxy handler to extract token metrics.
// captureFields controls what is saved in the ReqRespCapture using bitwise flags.
// if wrapHandler returns an error it is safe to assume that no // if wrapHandler returns an error it is safe to assume that no
// data was sent to the client // data was sent to the client
func (mp *metricsMonitor) wrapHandler( func (mp *metricsMonitor) wrapHandler(
modelID string, modelID string,
writer gin.ResponseWriter, writer gin.ResponseWriter,
request *http.Request, request *http.Request,
captureFields captureFields,
next func(modelID string, w http.ResponseWriter, r *http.Request) error, next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error { ) error {
// Capture request body and headers if captures enabled // Capture request body and headers if captures enabled
var reqBody []byte var reqBody []byte
var reqHeaders map[string]string var reqHeaders map[string]string
if mp.enableCaptures { if mp.enableCaptures && (captureFields&captureReqBody) != 0 {
if request.Body != nil { if request.Body != nil {
var err error var err error
reqBody, err = io.ReadAll(request.Body) reqBody, err = io.ReadAll(request.Body)
@@ -188,6 +282,8 @@ func (mp *metricsMonitor) wrapHandler(
request.Body.Close() request.Body.Close()
request.Body = io.NopCloser(bytes.NewBuffer(reqBody)) request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
} }
}
if mp.enableCaptures && (captureFields&captureReqHeaders) != 0 {
reqHeaders = make(map[string]string) reqHeaders = make(map[string]string)
for key, values := range request.Header { for key, values := range request.Header {
if len(values) > 0 { if len(values) > 0 {
@@ -211,22 +307,28 @@ func (mp *metricsMonitor) wrapHandler(
// after this point we have to assume that data was sent to the client // after this point we have to assume that data was sent to the client
// and we can only log errors but not send them to clients // and we can only log errors but not send them to clients
if recorder.Status() != http.StatusOK { // Initialize default metrics - recorded for every request
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path) tm := ActivityLogEntry{
return nil Timestamp: time.Now(),
Model: modelID,
ReqPath: request.URL.Path,
RespContentType: recorder.Header().Get("Content-Type"),
RespStatusCode: recorder.Status(),
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
} }
// Initialize default metrics - these will always be recorded if recorder.Status() != http.StatusOK {
tm := TokenMetrics{ mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), request.URL.Path)
Timestamp: time.Now(), tm.ID = mp.queueMetrics(tm)
Model: modelID, mp.emitMetric(tm)
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), return nil
} }
body := recorder.body.Bytes() body := recorder.body.Bytes()
if len(body) == 0 { if len(body) == 0 {
mp.logger.Warn("metrics: empty body, recording minimal metrics") mp.logger.Warn("metrics: empty body, recording minimal metrics")
mp.addMetrics(tm) tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm)
return nil return nil
} }
@@ -236,7 +338,8 @@ func (mp *metricsMonitor) wrapHandler(
body, err = decompressBody(body, encoding) body, err = decompressBody(body, encoding)
if err != nil { if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path) mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
mp.addMetrics(tm) tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm)
return nil return nil
} }
} }
@@ -244,7 +347,8 @@ func (mp *metricsMonitor) wrapHandler(
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path) mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else { } else {
tm = parsed tm.Tokens = parsed.Tokens
tm.DurationMs = parsed.DurationMs
} }
} else { } else {
if gjson.ValidBytes(body) { if gjson.ValidBytes(body) {
@@ -264,7 +368,8 @@ func (mp *metricsMonitor) wrapHandler(
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil { if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path) mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else { } else {
tm = parsedMetrics tm.Tokens = parsedMetrics.Tokens
tm.DurationMs = parsedMetrics.DurationMs
} }
} }
} else { } else {
@@ -275,39 +380,50 @@ func (mp *metricsMonitor) wrapHandler(
// Build capture if enabled and determine if it will be stored // Build capture if enabled and determine if it will be stored
var capture *ReqRespCapture var capture *ReqRespCapture
if mp.enableCaptures { if mp.enableCaptures {
respHeaders := make(map[string]string) var respHeaders map[string]string
for key, values := range recorder.Header() { var respBody []byte
if len(values) > 0 { if (captureFields & captureRespHeaders) != 0 {
respHeaders[key] = values[0] respHeaders = make(map[string]string)
for key, values := range recorder.Header() {
if len(values) > 0 {
respHeaders[key] = values[0]
}
} }
redactHeaders(respHeaders)
delete(respHeaders, "Content-Encoding")
}
if (captureFields & captureRespBody) != 0 {
respBody = body
} }
redactHeaders(respHeaders)
delete(respHeaders, "Content-Encoding")
capture = &ReqRespCapture{ capture = &ReqRespCapture{
ReqPath: request.URL.Path, ReqPath: request.URL.Path,
ReqHeaders: reqHeaders, ReqHeaders: reqHeaders,
ReqBody: reqBody, ReqBody: reqBody,
RespHeaders: respHeaders, RespHeaders: respHeaders,
RespBody: body, RespBody: respBody,
}
// Only set HasCapture if the capture will actually be stored (not too large)
if capture.Size() <= mp.maxCaptureSize {
tm.HasCapture = true
} }
} }
metricID := mp.addMetrics(tm) metricID := mp.queueMetrics(tm)
tm.ID = metricID
// Store capture if enabled // Store capture if enabled
if capture != nil { if capture != nil {
capture.ID = metricID capture.ID = metricID
mp.addCapture(*capture) if mp.addCapture(*capture) {
tm.HasCapture = true
mp.mu.Lock()
mp.metrics[len(mp.metrics)-1].HasCapture = true
mp.mu.Unlock()
}
} }
mp.emitMetric(tm)
return nil return nil
} }
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) { func processStreamingResponse(modelID string, start time.Time, body []byte) (ActivityLogEntry, error) {
// Iterate **backwards** through the body looking for the data payload with // Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split. // usage data. This avoids allocating a slice of all lines via bytes.Split.
@@ -361,10 +477,10 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
} }
} }
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream") return ActivityLogEntry{}, fmt.Errorf("no valid JSON data found in stream")
} }
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) { func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (ActivityLogEntry, error) {
wallDurationMs := int(time.Since(start).Milliseconds()) wallDurationMs := int(time.Since(start).Milliseconds())
// default values // default values
@@ -414,15 +530,17 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
} }
} }
return TokenMetrics{ return ActivityLogEntry{
Timestamp: time.Now(), Timestamp: time.Now(),
Model: modelID, Model: modelID,
CachedTokens: cachedTokens, Tokens: TokenMetrics{
InputTokens: inputTokens, CachedTokens: cachedTokens,
OutputTokens: outputTokens, InputTokens: inputTokens,
PromptPerSecond: promptPerSecond, OutputTokens: outputTokens,
TokensPerSecond: tokensPerSecond, PromptPerSecond: promptPerSecond,
DurationMs: durationMs, TokensPerSecond: tokensPerSecond,
},
DurationMs: durationMs,
}, nil }, nil
} }
+382 -168
View File
@@ -5,14 +5,17 @@ import (
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
"encoding/json" "encoding/json"
"math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/cache"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -21,27 +24,29 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
t.Run("adds metrics and assigns ID", func(t *testing.T) { t.Run("adds metrics and assigns ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
metric := TokenMetrics{ metric := ActivityLogEntry{
Model: "test-model", Model: "test-model",
InputTokens: 100, Tokens: TokenMetrics{
OutputTokens: 50, InputTokens: 100,
OutputTokens: 50,
},
} }
mm.addMetrics(metric) mm.queueMetrics(metric)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, 0, metrics[0].ID) assert.Equal(t, 0, metrics[0].ID)
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("increments ID for each metric", func(t *testing.T) { t.Run("increments ID for each metric", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{Model: "model"}) mm.queueMetrics(ActivityLogEntry{Model: "model"})
} }
metrics := mm.getMetrics() metrics := mm.getMetrics()
@@ -56,9 +61,11 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
// Add 5 metrics // Add 5 metrics
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{ mm.queueMetrics(ActivityLogEntry{
Model: "model", Model: "model",
InputTokens: i, Tokens: TokenMetrics{
InputTokens: i,
},
}) })
} }
@@ -71,29 +78,32 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
assert.Equal(t, 4, metrics[2].ID) assert.Equal(t, 4, metrics[2].ID)
}) })
t.Run("emits TokenMetricsEvent", func(t *testing.T) { t.Run("emits ActivityLogEvent", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
receivedEvent := make(chan TokenMetricsEvent, 1) receivedEvent := make(chan ActivityLogEvent, 1)
cancel := event.On(func(e TokenMetricsEvent) { cancel := event.On(func(e ActivityLogEvent) {
receivedEvent <- e receivedEvent <- e
}) })
defer cancel() defer cancel()
metric := TokenMetrics{ metric := ActivityLogEntry{
Model: "test-model", Model: "test-model",
InputTokens: 100, Tokens: TokenMetrics{
OutputTokens: 50, InputTokens: 100,
OutputTokens: 50,
},
} }
mm.addMetrics(metric) mm.queueMetrics(metric)
mm.emitMetric(metric)
select { select {
case evt := <-receivedEvent: case evt := <-receivedEvent:
assert.Equal(t, 0, evt.Metrics.ID) assert.Equal(t, 0, evt.Metrics.ID)
assert.Equal(t, "test-model", evt.Metrics.Model) assert.Equal(t, "test-model", evt.Metrics.Model)
assert.Equal(t, 100, evt.Metrics.InputTokens) assert.Equal(t, 100, evt.Metrics.Tokens.InputTokens)
assert.Equal(t, 50, evt.Metrics.OutputTokens) assert.Equal(t, 50, evt.Metrics.Tokens.OutputTokens)
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for event") t.Fatal("timeout waiting for event")
} }
@@ -110,8 +120,8 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
t.Run("returns copy of metrics", func(t *testing.T) { t.Run("returns copy of metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
mm.addMetrics(TokenMetrics{Model: "model1"}) mm.queueMetrics(ActivityLogEntry{Model: "model1"})
mm.addMetrics(TokenMetrics{Model: "model2"}) mm.queueMetrics(ActivityLogEntry{Model: "model2"})
metrics1 := mm.getMetrics() metrics1 := mm.getMetrics()
metrics2 := mm.getMetrics() metrics2 := mm.getMetrics()
@@ -134,7 +144,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, jsonData) assert.NotNil(t, jsonData)
var metrics []TokenMetrics var metrics []ActivityLogEntry
err = json.Unmarshal(jsonData, &metrics) err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 0, len(metrics))
@@ -142,23 +152,27 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
t.Run("returns valid JSON with metrics", func(t *testing.T) { t.Run("returns valid JSON with metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
mm.addMetrics(TokenMetrics{ mm.queueMetrics(ActivityLogEntry{
Model: "model1", Model: "model1",
InputTokens: 100, Tokens: TokenMetrics{
OutputTokens: 50, InputTokens: 100,
TokensPerSecond: 25.5, OutputTokens: 50,
TokensPerSecond: 25.5,
},
}) })
mm.addMetrics(TokenMetrics{ mm.queueMetrics(ActivityLogEntry{
Model: "model2", Model: "model2",
InputTokens: 200, Tokens: TokenMetrics{
OutputTokens: 100, InputTokens: 200,
TokensPerSecond: 30.0, OutputTokens: 100,
TokensPerSecond: 30.0,
},
}) })
jsonData, err := mm.getMetricsJSON() jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err) assert.NoError(t, err)
var metrics []TokenMetrics var metrics []ActivityLogEntry
err = json.Unmarshal(jsonData, &metrics) err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, len(metrics)) assert.Equal(t, 2, len(metrics))
@@ -189,14 +203,14 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("successful request with timings data", func(t *testing.T) { t.Run("successful request with timings data", func(t *testing.T) {
@@ -225,17 +239,17 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
assert.Equal(t, 20, metrics[0].CachedTokens) assert.Equal(t, 20, metrics[0].Tokens.CachedTokens)
assert.Equal(t, 150.5, metrics[0].PromptPerSecond) assert.Equal(t, 150.5, metrics[0].Tokens.PromptPerSecond)
assert.Equal(t, 25.5, metrics[0].TokensPerSecond) assert.Equal(t, 25.5, metrics[0].Tokens.TokensPerSecond)
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500 assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
}) })
@@ -264,18 +278,18 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
// When timings data is present, it takes precedence // When timings data is present, it takes precedence
assert.Equal(t, 10, metrics[0].InputTokens) assert.Equal(t, 10, metrics[0].Tokens.InputTokens)
assert.Equal(t, 20, metrics[0].OutputTokens) assert.Equal(t, 20, metrics[0].Tokens.OutputTokens)
}) })
t.Run("non-OK status code does not record metrics", func(t *testing.T) { t.Run("non-OK status code records partial metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0) mm := newMetricsMonitor(testLogger, 10, 0)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
@@ -288,11 +302,16 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, "/test", metrics[0].ReqPath)
assert.Equal(t, http.StatusBadRequest, metrics[0].RespStatusCode)
assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("empty response body records minimal metrics", func(t *testing.T) { t.Run("empty response body records minimal metrics", func(t *testing.T) {
@@ -307,14 +326,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("invalid JSON records minimal metrics", func(t *testing.T) { t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
@@ -331,14 +350,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("next handler error is propagated", func(t *testing.T) { t.Run("next handler error is propagated", func(t *testing.T) {
@@ -353,7 +372,7 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.Equal(t, expectedErr, err) assert.Equal(t, expectedErr, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
@@ -376,14 +395,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("infill request extracts timings from last array element", func(t *testing.T) { t.Run("infill request extracts timings from last array element", func(t *testing.T) {
@@ -415,17 +434,17 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 150, metrics[0].InputTokens) assert.Equal(t, 150, metrics[0].Tokens.InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens) assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
assert.Equal(t, 30, metrics[0].CachedTokens) assert.Equal(t, 30, metrics[0].Tokens.CachedTokens)
assert.Equal(t, 200.5, metrics[0].PromptPerSecond) assert.Equal(t, 200.5, metrics[0].Tokens.PromptPerSecond)
assert.Equal(t, 35.5, metrics[0].TokensPerSecond) assert.Equal(t, 35.5, metrics[0].Tokens.TokensPerSecond)
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800 assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
}) })
@@ -445,14 +464,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
} }
@@ -506,7 +525,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
} }
func TestMetricsMonitor_Concurrent(t *testing.T) { func TestMetricsMonitor_Concurrent(t *testing.T) {
t.Run("concurrent addMetrics is safe", func(t *testing.T) { t.Run("concurrent queueMetrics is safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 1000, 0) mm := newMetricsMonitor(testLogger, 1000, 0)
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -518,10 +537,12 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
for j := 0; j < metricsPerGoroutine; j++ { for j := 0; j < metricsPerGoroutine; j++ {
mm.addMetrics(TokenMetrics{ mm.queueMetrics(ActivityLogEntry{
Model: "test-model", Model: "test-model",
InputTokens: id*1000 + j, Tokens: TokenMetrics{
OutputTokens: j, InputTokens: id*1000 + j,
OutputTokens: j,
},
}) })
} }
}(i) }(i)
@@ -541,7 +562,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
// Writer goroutine // Writer goroutine
go func() { go func() {
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
mm.addMetrics(TokenMetrics{Model: "test-model"}) mm.queueMetrics(ActivityLogEntry{Model: "test-model"})
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
done <- true done <- true
@@ -585,10 +606,10 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
metrics, err := parseMetrics("test-model", start, usage, timings) metrics, err := parseMetrics("test-model", start, usage, timings)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 5, metrics.InputTokens) assert.Equal(t, 5, metrics.Tokens.InputTokens)
assert.Equal(t, 1, metrics.OutputTokens) assert.Equal(t, 1, metrics.Tokens.OutputTokens)
assert.Equal(t, 10.0, metrics.PromptPerSecond) assert.Equal(t, 10.0, metrics.Tokens.PromptPerSecond)
assert.Equal(t, 2.0, metrics.TokensPerSecond) assert.Equal(t, 2.0, metrics.Tokens.TokensPerSecond)
assert.GreaterOrEqual(t, metrics.DurationMs, 5000) assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
}) })
@@ -622,14 +643,14 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
// Should use timings values, not usage values // Should use timings values, not usage values
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("handles missing cache_n in timings", func(t *testing.T) { t.Run("handles missing cache_n in timings", func(t *testing.T) {
@@ -657,12 +678,12 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present assert.Equal(t, -1, metrics[0].Tokens.CachedTokens) // Default value when not present
}) })
} }
@@ -692,13 +713,13 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) { t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
@@ -721,14 +742,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("v1/responses format with nested response.usage", func(t *testing.T) { t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
@@ -750,14 +771,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 17, metrics[0].InputTokens) assert.Equal(t, 17, metrics[0].Tokens.InputTokens)
assert.Equal(t, 23, metrics[0].OutputTokens) assert.Equal(t, 23, metrics[0].Tokens.OutputTokens)
}) })
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) { t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
@@ -776,14 +797,14 @@ data: [DONE]
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
} }
@@ -791,20 +812,22 @@ data: [DONE]
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) { func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
mm := newMetricsMonitor(testLogger, 1000, 0) mm := newMetricsMonitor(testLogger, 1000, 0)
metric := TokenMetrics{ metric := ActivityLogEntry{
Model: "test-model", Model: "test-model",
CachedTokens: 100, Tokens: TokenMetrics{
InputTokens: 500, CachedTokens: 100,
OutputTokens: 250, InputTokens: 500,
PromptPerSecond: 1200.5, OutputTokens: 250,
TokensPerSecond: 45.8, PromptPerSecond: 1200.5,
DurationMs: 5000, TokensPerSecond: 45.8,
Timestamp: time.Now(), },
DurationMs: 5000,
Timestamp: time.Now(),
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
mm.addMetrics(metric) mm.queueMetrics(metric)
} }
} }
@@ -812,20 +835,22 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
// Test performance with a smaller buffer where wrapping occurs more frequently // Test performance with a smaller buffer where wrapping occurs more frequently
mm := newMetricsMonitor(testLogger, 100, 0) mm := newMetricsMonitor(testLogger, 100, 0)
metric := TokenMetrics{ metric := ActivityLogEntry{
Model: "test-model", Model: "test-model",
CachedTokens: 100, Tokens: TokenMetrics{
InputTokens: 500, CachedTokens: 100,
OutputTokens: 250, InputTokens: 500,
PromptPerSecond: 1200.5, OutputTokens: 250,
TokensPerSecond: 45.8, PromptPerSecond: 1200.5,
DurationMs: 5000, TokensPerSecond: 45.8,
Timestamp: time.Now(), },
DurationMs: 5000,
Timestamp: time.Now(),
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
mm.addMetrics(metric) mm.queueMetrics(metric)
} }
} }
@@ -854,14 +879,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens) assert.Equal(t, 100, metrics[0].Tokens.InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].Tokens.OutputTokens)
}) })
t.Run("deflate encoded response", func(t *testing.T) { t.Run("deflate encoded response", func(t *testing.T) {
@@ -888,14 +913,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 200, metrics[0].InputTokens) assert.Equal(t, 200, metrics[0].Tokens.InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens) assert.Equal(t, 75, metrics[0].Tokens.OutputTokens)
}) })
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) { t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
@@ -916,14 +941,14 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) // Should not return error, just log warning assert.NoError(t, err) // Should not return error, just log warning
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model) assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens) assert.Equal(t, 0, metrics[0].Tokens.InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens) assert.Equal(t, 0, metrics[0].Tokens.OutputTokens)
}) })
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) { t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
@@ -943,38 +968,37 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, 300, metrics[0].InputTokens) assert.Equal(t, 300, metrics[0].Tokens.InputTokens)
assert.Equal(t, 100, metrics[0].OutputTokens) assert.Equal(t, 100, metrics[0].Tokens.OutputTokens)
}) })
} }
func TestReqRespCapture_Size(t *testing.T) { func TestReqRespCapture_CompressedSize(t *testing.T) {
t.Run("calculates size correctly", func(t *testing.T) { t.Run("compressed size is smaller than uncompressed", func(t *testing.T) {
capture := ReqRespCapture{ capture := ReqRespCapture{
ID: 1, ID: 1,
ReqPath: "/v1/chat/completions", // 20 bytes ReqPath: "/v1/chat/completions",
ReqHeaders: map[string]string{ ReqBody: []byte(`{"model":"test","prompt":"hello world this is a test request body that is reasonably long"}`),
"Content-Type": "application/json", // 12 + 16 = 28 RespBody: []byte(`{"id":"resp-123","object":"chat.completion","created":1234567890,"model":"test-model","choices":[{"index":0,"message":{"role":"assistant","content":"This is a test response body with some meaningful content to compress"}},{"index":1,"message":{"role":"user","content":"Another message here"}}]}`),
},
ReqBody: []byte("request body"), // 12 bytes
RespHeaders: map[string]string{
"X-Test": "value", // 6 + 5 = 11
},
RespBody: []byte("response body"), // 13 bytes
} }
// Expected: 20 + 12 + 13 + 28 + 11 = 84 compressed, uncompressed, err := compressCapture(&capture)
assert.Equal(t, 84, capture.Size()) assert.NoError(t, err)
assert.Greater(t, uncompressed, 0)
assert.True(t, len(compressed) < uncompressed, "compressed (%d bytes) should be smaller than uncompressed JSON (%d bytes)", len(compressed), uncompressed)
}) })
t.Run("handles empty capture", func(t *testing.T) { t.Run("empty capture produces compressed output", func(t *testing.T) {
capture := ReqRespCapture{} capture := ReqRespCapture{}
assert.Equal(t, 0, capture.Size()) compressed, _, err := compressCapture(&capture)
assert.NoError(t, err)
assert.NotNil(t, compressed)
assert.True(t, len(compressed) > 0)
}) })
} }
@@ -1002,21 +1026,27 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
} }
mm.addCapture(capture) mm.addCapture(capture)
retrieved := mm.getCaptureByID(0) captured := mm.getCaptureByID(0)
assert.NotNil(t, retrieved) assert.NotNil(t, captured)
assert.Equal(t, 0, retrieved.ID) assert.Equal(t, 0, captured.ID)
assert.Equal(t, []byte("test request"), retrieved.ReqBody) assert.Equal(t, []byte("test request"), captured.ReqBody)
assert.Equal(t, []byte("test response"), retrieved.RespBody) assert.Equal(t, []byte("test response"), captured.RespBody)
}) })
t.Run("evicts oldest when exceeding max size", func(t *testing.T) { t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100 // Set small limit for test // Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
mm.captureCache = cache.New(450)
// Add captures that will exceed the limit // Use random-looking data that doesn't compress well with zstd
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)} rng := rand.New(rand.NewSource(42))
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)} capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 80)}
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)} rng.Read(capture1.ReqBody)
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 80)}
rng.Read(capture2.ReqBody)
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 80)}
rng.Read(capture3.ReqBody)
mm.addCapture(capture1) mm.addCapture(capture1)
mm.addCapture(capture2) mm.addCapture(capture2)
@@ -1030,10 +1060,12 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
t.Run("skips capture larger than max size", func(t *testing.T) { t.Run("skips capture larger than max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100 mm.captureCache = cache.New(100)
// Add a capture larger than max // Use random data that doesn't compress well to create an oversized capture
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)} rng := rand.New(rand.NewSource(99))
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 300)}
rng.Read(largeCapture.ReqBody)
mm.addCapture(largeCapture) mm.addCapture(largeCapture)
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored") assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
@@ -1047,18 +1079,39 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
assert.Nil(t, mm.getCaptureByID(999)) assert.Nil(t, mm.getCaptureByID(999))
}) })
t.Run("returns capture by ID", func(t *testing.T) { t.Run("returns decompressed capture by ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{ capture := ReqRespCapture{
ID: 42, ID: 42,
ReqBody: []byte("test"), ReqBody: []byte("test request"),
RespBody: []byte("test response"),
} }
mm.addCapture(capture) mm.addCapture(capture)
retrieved := mm.getCaptureByID(42) captured := mm.getCaptureByID(42)
assert.NotNil(t, retrieved) assert.NotNil(t, captured)
assert.Equal(t, 42, retrieved.ID) assert.Equal(t, 42, captured.ID)
assert.Equal(t, []byte("test request"), captured.ReqBody)
assert.Equal(t, []byte("test response"), captured.RespBody)
})
t.Run("stores data as compressed bytes", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{
ID: 42,
ReqBody: []byte("test request body"),
RespBody: []byte("test response body"),
}
mm.addCapture(capture)
compressed, exists := mm.getCompressedBytes(42)
assert.True(t, exists)
assert.NotNil(t, compressed)
// Compressed data should not be valid CBOR (it's zstd-compressed)
var decoded ReqRespCapture
assert.Error(t, cbor.Unmarshal(compressed, &decoded))
}) })
} }
@@ -1127,7 +1180,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
// Check metric was recorded // Check metric was recorded
@@ -1135,7 +1188,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
metricID := metrics[0].ID metricID := metrics[0].ID
// Check capture was stored with same ID // Check capture was stored with same ID (decompressed)
capture := mm.getCaptureByID(metricID) capture := mm.getCaptureByID(metricID)
assert.NotNil(t, capture) assert.NotNil(t, capture)
assert.Equal(t, metricID, capture.ID) assert.Equal(t, metricID, capture.ID)
@@ -1165,7 +1218,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureAll, nextHandler)
assert.NoError(t, err) assert.NoError(t, err)
// Metrics should still be recorded // Metrics should still be recorded
@@ -1173,7 +1226,168 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
// But no capture // But no capture
capture := mm.getCaptureByID(metrics[0].ID) assert.Nil(t, mm.getCaptureByID(metrics[0].ID))
assert.Nil(t, capture) })
}
func TestMetricsMonitor_WrapHandler_PartialCaptures(t *testing.T) {
requestBody := `{"model": "test"}`
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Custom", "header-value")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
t.Run("only request headers", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("only request body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("only response headers", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespHeaders, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
assert.Nil(t, capture.RespBody)
})
t.Run("only response body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Equal(t, []byte(responseBody), capture.RespBody)
})
t.Run("captureReqAll", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqAll, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("captureRespAll", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureRespAll, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
assert.Equal(t, []byte(responseBody), capture.RespBody)
})
t.Run("no flags", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureFields(0), nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Nil(t, capture.ReqHeaders)
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Nil(t, capture.RespBody)
})
t.Run("mixed flags req headers and resp body", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 100)
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, captureReqHeaders|captureRespBody, nextHandler)
assert.NoError(t, err)
capture := mm.getCaptureByID(mm.getMetrics()[0].ID)
assert.NotNil(t, capture)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Nil(t, capture.ReqBody)
assert.Nil(t, capture.RespHeaders)
assert.Equal(t, []byte(responseBody), capture.RespBody)
}) })
} }
+1 -1
View File
@@ -42,7 +42,7 @@ func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second, Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
}).DialContext, }).DialContext,
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second, TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second, ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
+63 -5
View File
@@ -77,6 +77,9 @@ type Process struct {
// used for testing to override the default value // used for testing to override the default value
gracefulStopTimeout time.Duration gracefulStopTimeout time.Duration
// used for testing to bypass subprocess and reverse proxy
testHandler http.Handler
// track the number of failed starts // track the number of failed starts
failedStartCount int failedStartCount int
} }
@@ -102,7 +105,7 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: time.Duration(config.Timeouts.Connect) * time.Second, Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second,
}).DialContext, }).DialContext,
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second, TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second, ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
@@ -236,6 +239,49 @@ func (p *Process) forceState(newState ProcessState) {
// at any time. // at any time.
func (p *Process) start() error { func (p *Process) start() error {
// test-only fast path: skip subprocess, health check, and TTL goroutine
if p.testHandler != nil {
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
if curState == StateStarting {
p.waitStarting.Wait()
curState = p.CurrentState()
if curState == StateReady {
return nil
}
return fmt.Errorf("process was already starting but wound up in state %v", curState)
}
return fmt.Errorf("process was in state %v when start() was called", curState)
}
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
defer p.waitStarting.Done()
// Mimic the real stop path: cancelUpstream transitions
// StateStopping -> StateStopped and closes cmdWaitChan,
// matching what waitForCmd does for real subprocesses.
ch := make(chan struct{})
p.cmdMutex.Lock()
p.cancelUpstream = func() {
if curState := p.CurrentState(); curState == StateStopping {
if _, err := p.swapState(StateStopping, StateStopped); err != nil {
p.forceState(StateStopped)
}
} else {
p.forceState(StateStopped)
}
close(ch)
}
p.cmdWaitChan = ch
p.cmdMutex.Unlock()
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
}
p.failedStartCount = 0
return nil
}
if p.config.Proxy == "" { if p.config.Proxy == "" {
return fmt.Errorf("can not start(), upstream proxy missing") return fmt.Errorf("can not start(), upstream proxy missing")
} }
@@ -386,7 +432,10 @@ func (p *Process) start() error {
// Stop will wait for inflight requests to complete before stopping the process. // Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() { func (p *Process) Stop() {
// guard to prevent multiple goroutines from stopping
if !isValidTransition(p.CurrentState(), StateStopping) { if !isValidTransition(p.CurrentState(), StateStopping) {
p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
return return
} }
@@ -399,13 +448,17 @@ func (p *Process) Stop() {
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM. // StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL. // If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() { func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
// guard to prevent multiple goroutines from stopping the process
enterState := p.CurrentState()
if !isValidTransition(enterState, StateStopping) {
p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
return return
} }
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState()) p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState)
if curState, err := p.swapState(StateReady, StateStopping); err != nil { if curState, err := p.swapState(enterState, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState)
return return
} }
@@ -577,6 +630,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if !srw.waitForCompletion(completionTimeout) { if !srw.waitForCompletion(completionTimeout) {
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout) p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
} }
}
if p.testHandler != nil {
p.testHandler.ServeHTTP(w, r)
} else if srw != nil {
p.reverseProxy.ServeHTTP(srw, r) p.reverseProxy.ServeHTTP(srw, r)
} else { } else {
p.reverseProxy.ServeHTTP(w, r) p.reverseProxy.ServeHTTP(w, r)
+36
View File
@@ -24,6 +24,22 @@ type ProcessGroup struct {
// map of current processes // map of current processes
processes map[string]*Process processes map[string]*Process
lastUsedProcess string lastUsedProcess string
// inflight tracks fast-path requests (requests for the already-selected
// model in a swap group). Fast-path requests Add(1) while holding pg.Lock
// and Done() on completion; a concurrent swap request calls inflight.Wait()
// under pg.Lock before stopping the current process. Without this tracking,
// a fast-path request that has released pg.Lock but has not yet called
// Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be
// killed mid-request.
inflight sync.WaitGroup
// testDelayFastPath is a test-only hook that, when non-nil, is invoked in
// the fast path after pg.Lock is released but before the request is
// dispatched to Process.ProxyRequest. Tests use it to park a fast-path
// request at the exact race window to deterministically reproduce the
// fast-path vs swap race.
testDelayFastPath func()
} }
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
@@ -64,6 +80,13 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
pg.Lock() pg.Lock()
if pg.lastUsedProcess != modelID { if pg.lastUsedProcess != modelID {
// Wait for in-flight fast-path requests to drain before stopping
// the previous process. Without this, a fast-path request that has
// released pg.Lock but has not yet incremented
// Process.inFlightRequests races with Stop() and can be killed
// mid-request.
pg.inflight.Wait()
// is there something already running? // is there something already running?
if pg.lastUsedProcess != "" { if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop() pg.processes[pg.lastUsedProcess].Stop()
@@ -78,7 +101,16 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
pg.Unlock() pg.Unlock()
return nil return nil
} }
// Fast path: register this request in inflight before releasing
// pg.Lock so a concurrent swap will wait for it to complete.
pg.inflight.Add(1)
defer pg.inflight.Done()
pg.Unlock() pg.Unlock()
if pg.testDelayFastPath != nil {
pg.testDelayFastPath()
}
} }
pg.processes[modelID].ProxyRequest(writer, request) pg.processes[modelID].ProxyRequest(writer, request)
@@ -123,6 +155,10 @@ func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock() pg.Lock()
defer pg.Unlock() defer pg.Unlock()
if strategy != StopImmediately {
pg.inflight.Wait()
}
if len(pg.processes) == 0 { if len(pg.processes) == 0 {
return return
} }
+226
View File
@@ -4,11 +4,14 @@ import (
"bytes" "bytes"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"runtime"
"sync" "sync"
"testing" "testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config" "github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
@@ -95,6 +98,229 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
wg.Wait() wg.Wait()
} }
// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap
// request cannot stop the current process while a fast-path request (for the
// already-selected model) is in flight. Without ProcessGroup-level inflight
// tracking, a fast-path request that has released pg.Lock but has not yet
// incremented Process.inFlightRequests races with Stop()'s Wait() and the
// process is killed mid-request.
func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
cfg := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2"},
},
},
})
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
defer pg.StopProcesses(StopImmediately)
// Bypass real subprocesses so the test is fast and deterministic.
pg.processes["model1"].testHandler = newTestHandler("model1")
pg.processes["model2"].testHandler = newTestHandler("model2")
// Prime: run a request through model1 via the swap path so that
// lastUsedProcess == "model1" and subsequent model1 requests take the
// fast path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
require.Equal(t, StateStopped, pg.processes["model2"].CurrentState())
// Fast-path hook: signal arrival at the race window, then wait for
// release. This parks R2 deterministically at the point where pg.Lock
// has been released but Process.inFlightRequests has not yet been
// incremented — the exact window the race exploits.
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
pg.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
// R2: fast-path request for model1. Will pause at the test hook.
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
}()
// Deterministically wait for R2 to reach the race window.
<-r2Reached
// R3: swap request for model2. Must wait for R2 to finish before touching
// model1, otherwise model1 gets killed mid-request.
r3Done := make(chan struct{})
w3 := httptest.NewRecorder()
go func() {
defer close(r3Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model2", w3, req))
}()
// Spin until R3 has acquired pg.Lock and entered the swap critical
// section. In the fixed code, R3 then blocks on pg.inflight.Wait() while
// still holding the lock, so TryLock keeps failing.
for pg.TryLock() {
pg.Unlock()
runtime.Gosched()
}
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
// state. In the fixed code, R3 is blocked on pg.inflight.Wait() and
// nothing changes, so we wait the full window. In the buggy code, R3
// will Stop() model1 and start serving via model2 within microseconds —
// we exit early once the mutation is observable.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if pg.processes["model1"].CurrentState() != StateReady ||
pg.processes["model2"].CurrentState() != StateStopped {
break
}
done := false
select {
case <-r3Done:
done = true
default:
}
if done {
break
}
runtime.Gosched()
}
// Invariant: R3 must be blocked while R2 is still in flight.
select {
case <-r3Done:
t.Fatal("swap completed while fast-path request was still in flight — race not prevented")
default:
}
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
"model1 must stay Ready while a fast-path request is in flight")
assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(),
"model2 must not be started until R2 finishes and model1 is swapped out")
// Release R2 and let both requests finish.
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
assert.Equal(t, http.StatusOK, w3.Code)
assert.Contains(t, w3.Body.String(), "model2")
}
// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses
// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a
// process while a fast-path ProxyRequest is in the [pg.Unlock,
// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in
// StopProcesses, the external caller bypasses the inflight guard and kills the
// process mid-request.
func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) {
cfg := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2"},
},
},
})
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
defer pg.StopProcesses(StopImmediately)
pg.processes["model1"].testHandler = newTestHandler("model1")
pg.processes["model2"].testHandler = newTestHandler("model2")
// Prime: model1 is active so subsequent model1 requests take the fast path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
// Park a fast-path request at the race window.
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
pg.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
}()
<-r2Reached
// Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping
// the group while a fast-path request is in flight.
r3Done := make(chan struct{})
go func() {
defer close(r3Done)
pg.StopProcesses(StopWaitForInflightRequest)
}()
// Spin until StopProcesses has acquired pg.Lock.
for pg.TryLock() {
pg.Unlock()
runtime.Gosched()
}
// Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait()
// and model1 stays Ready. In the buggy code it proceeds immediately and
// kills model1.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if pg.processes["model1"].CurrentState() != StateReady {
break
}
select {
case <-r3Done:
goto done
default:
}
runtime.Gosched()
}
done:
select {
case <-r3Done:
t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented")
default:
}
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
"model1 must stay Ready while a fast-path request is in flight")
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest) defer pg.StopProcesses(StopWaitForInflightRequest)
+400 -285
View File
@@ -77,6 +77,9 @@ type ProxyManager struct {
processGroups map[string]*ProcessGroup processGroups map[string]*ProcessGroup
// matrix-based swap (mutually exclusive with processGroups)
matrix *Matrix
inFlightCounter *InflightCounter inFlightCounter *InflightCounter
// shutdown signaling // shutdown signaling
@@ -203,10 +206,14 @@ func New(proxyConfig config.Config) *ProxyManager {
peerProxy: peerProxy, peerProxy: peerProxy,
} }
// create the process groups // create either matrix or process groups (mutually exclusive)
for groupID := range proxyConfig.Groups { if proxyConfig.Matrix != nil {
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger) pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup } else {
for groupID := range proxyConfig.Groups {
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
} }
pm.setupGinEngine() pm.setupGinEngine()
@@ -225,18 +232,29 @@ func New(proxyConfig config.Config) *ProxyManager {
} }
proxyLogger.Infof("Preloading model: %s", modelID) proxyLogger.Infof("Preloading model: %s", modelID)
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil { var preloadErr error
req, _ := http.NewRequest("GET", "/", nil)
if pm.matrix != nil {
preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
preloadErr = err
} else {
preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
}
}
if preloadErr != nil {
event.Emit(ModelPreloadedEvent{ event.Emit(ModelPreloadedEvent{
ModelName: modelID, ModelName: modelID,
Success: false, Success: false,
}) })
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err) proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
continue continue
} else { } else {
req, _ := http.NewRequest("GET", "/", nil)
processGroup.ProxyRequest(modelID, discardWriter, req)
event.Emit(ModelPreloadedEvent{ event.Emit(ModelPreloadedEvent{
ModelName: modelID, ModelName: modelID,
Success: true, Success: true,
@@ -314,41 +332,77 @@ func (pm *ProxyManager) setupGinEngine() {
// Set up routes using the Gin engine // Set up routes using the Gin engine
// Protected routes use pm.apiKeyAuth() middleware // Protected routes use pm.apiKeyAuth() middleware
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) llmHandler := pm.mkProxyJSONHandler(captureAll)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support legacy /v1/completions api, see issue #12 // Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570) // Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support anthropic count_tokens API (Also added in the above PR) // Support anthropic count_tokens API (Also added in the above PR)
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support embeddings and reranking // Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /reranking endpoint + aliases // llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /infill endpoint for code infilling // llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// llama-server's /completion endpoint // llama-server's /completion endpoint
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), llmHandler)
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST(
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) "/v1/audio/speech",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST(
"/v1/audio/voices",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqHeaders|captureRespAll),
)
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler) pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST(
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler) "/v1/audio/transcriptions",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders|captureRespBody),
)
pm.ginEngine.POST(
"/v1/images/generations",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST(
"/v1/images/edits",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkPostFormHandler(captureReqHeaders|captureRespHeaders),
)
// sd.cpp /sdapi/v1 endpoints // sd.cpp /sdapi/v1 endpoints
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.ginEngine.POST("/sdapi/v1/txt2img",
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler) pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqAll|captureRespHeaders),
)
pm.ginEngine.POST("/sdapi/v1/img2img",
pm.apiKeyAuth(),
pm.trackInflight(),
pm.mkProxyJSONHandler(captureReqHeaders|captureRespHeaders),
)
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler) pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
@@ -453,6 +507,11 @@ func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
if pm.matrix != nil {
pm.matrix.StopProcesses(strategy)
return
}
// stop Processes in parallel // stop Processes in parallel
var wg sync.WaitGroup var wg sync.WaitGroup
for _, processGroup := range pm.processGroups { for _, processGroup := range pm.processGroups {
@@ -473,6 +532,12 @@ func (pm *ProxyManager) Shutdown() {
pm.proxyLogger.Debug("Shutdown() called in proxy manager") pm.proxyLogger.Debug("Shutdown() called in proxy manager")
if pm.matrix != nil {
pm.matrix.Shutdown()
pm.shutdownCancel()
return
}
var wg sync.WaitGroup var wg sync.WaitGroup
// Send shutdown signal to all process in groups // Send shutdown signal to all process in groups
for _, processGroup := range pm.processGroups { for _, processGroup := range pm.processGroups {
@@ -639,10 +704,16 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return return
} }
processGroup, err := pm.swapProcessGroup(modelID) var handler func(string, http.ResponseWriter, *http.Request) error
if err != nil { if pm.matrix != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) handler = pm.matrix.ProxyRequest
return } else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
handler = processGroup.ProxyRequest
} }
// rewrite the path // rewrite the path
@@ -651,13 +722,13 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// attempt to record metrics if it is a POST request // attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" { if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, captureNone, handler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath) pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return return
} }
} else { } else {
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil { if err := handler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath) pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
return return
@@ -665,270 +736,294 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
} }
} }
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) { func (pm *ProxyManager) mkProxyJSONHandler(cf captureFields) func(*gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body) return func(c *gin.Context) {
if err != nil { bodyBytes, err := io.ReadAll(c.Request.Body)
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
return
}
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
// Look for a matching local model first
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
return return
} }
// issue #69 allow custom model names to be sent to upstream requestedModel := gjson.GetBytes(bodyBytes, "model").String()
useModelName := pm.config.Models[modelID].UseModelName if requestedModel == "" {
if useModelName != "" { pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) return
if err != nil { }
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
// Look for a matching local model first
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
modelID, found := pm.config.RealModelName(requestedModel)
if found {
var localHandler func(string, http.ResponseWriter, *http.Request) error
if pm.matrix != nil {
localHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
localHandler = processGroup.ProxyRequest
}
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[modelID].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
return
}
}
}
// issue #453 set/override parameters in the JSON body
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
for _, key := range setParamsByIDKeys {
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = localHandler
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
// issue #453 apply filters for peer requests
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
// Apply stripParams - remove specified parameters from request
stripParams := peerFilters.SanitizedStripParams()
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
return
}
}
// Apply setParams - set/override specified parameters in request
setParams, setParamKeys := peerFilters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, cf, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return return
} }
} }
}
}
// issue #174 strip parameters from the JSON body // mkPostFormHandler creates a POST form handler for inference backends
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams() // with a custom captureFields to filter out large binary requests or responses.
if err != nil { // just log it and continue func (pm *ProxyManager) mkPostFormHandler(cf captureFields) func(*gin.Context) {
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error()) return func(c *gin.Context) {
} else { // Parse multipart form
for _, param := range stripParams { if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param) pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Look for a matching local model first, then check peers
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
var useModelName string
modelID, found := pm.config.RealModelName(requestedModel)
if found {
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return return
} }
} }
} }
// issue #453 set/override parameters in the JSON body // Copy all files from the original request
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams() for key, fileHeaders := range c.Request.MultipartForm.File {
for _, key := range setParamKeys { for _, fileHeader := range fileHeaders {
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key) formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) if err != nil {
if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) return
return
}
}
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
for _, key := range setParamsByIDKeys {
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
// issue #453 apply filters for peer requests
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
// Apply stripParams - remove specified parameters from request
stripParams := peerFilters.SanitizedStripParams()
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
return
}
}
// Apply setParams - set/override specified parameters in request
setParams, setParamKeys := peerFilters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
return
}
// Get model parameter from the form
requestedModel := c.Request.FormValue("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
return
}
// Look for a matching local model first, then check peers
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
var useModelName string
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
if useModelName != "" {
fieldValue = useModelName
} else {
fieldValue = requestedModel
} }
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
return
}
if _, err = field.Write([]byte(fieldValue)); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
return
}
}
}
// Copy all files from the original request file, err := fileHeader.Open()
for key, fileHeaders := range c.Request.MultipartForm.File { if err != nil {
for _, fileHeader := range fileHeaders { pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename) return
if err != nil { }
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
return
}
file, err := fileHeader.Open() if _, err = io.Copy(formFile, file); err != nil {
if err != nil { file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file") pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
return return
} }
if _, err = io.Copy(formFile, file); err != nil {
file.Close() file.Close()
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data") }
}
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if pm.metricsMonitor != nil {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, modifiedReq, cf, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
} else {
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return return
} }
file.Close()
} }
} }
// Close the multipart writer to finalize the form
if err := multipartWriter.Close(); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
return
}
// Create a new request with the reconstructed form data
modifiedReq, err := http.NewRequestWithContext(
c.Request.Context(),
c.Request.Method,
c.Request.URL.String(),
&requestBuffer,
)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
return
}
// Copy the headers from the original request
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
} }
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) { func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
@@ -942,14 +1037,18 @@ func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
var modelID string var modelID string
if realModelID, found := pm.config.RealModelName(requestedModel); found { if realModelID, found := pm.config.RealModelName(requestedModel); found {
processGroup, err := pm.swapProcessGroup(realModelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
modelID = realModelID modelID = realModelID
if pm.matrix != nil {
nextHandler = pm.matrix.ProxyRequest
} else {
processGroup, err := pm.swapProcessGroup(realModelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
nextHandler = processGroup.ProxyRequest
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
modelID = requestedModel modelID = requestedModel
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
@@ -1048,9 +1147,9 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
context.Header("Content-Type", "application/json") context.Header("Content-Type", "application/json")
runningProcesses := make([]gin.H, 0) // Default to an empty response. runningProcesses := make([]gin.H, 0) // Default to an empty response.
for _, processGroup := range pm.processGroups { if pm.matrix != nil {
for _, process := range processGroup.processes { for _, modelID := range pm.matrix.RunningModels() {
if process.CurrentState() == StateReady { if process, ok := pm.matrix.GetProcess(modelID); ok {
runningProcesses = append(runningProcesses, gin.H{ runningProcesses = append(runningProcesses, gin.H{
"model": process.ID, "model": process.ID,
"state": process.state, "state": process.state,
@@ -1062,6 +1161,22 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
}) })
} }
} }
} else {
for _, processGroup := range pm.processGroups {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
"cmd": process.config.Cmd,
"proxy": process.config.Proxy,
"ttl": process.config.UnloadAfter,
"name": process.config.Name,
"description": process.config.Description,
})
}
}
}
} }
// Put the results under the `running` key. // Put the results under the `running` key.
+43 -32
View File
@@ -55,27 +55,28 @@ func (pm *ProxyManager) getModelStatus() []Model {
// Iterate over sorted keys // Iterate over sorted keys
for _, modelID := range modelIDs { for _, modelID := range modelIDs {
// Get process state // Get process state
processGroup := pm.findGroupByModelName(modelID)
state := "unknown" state := "unknown"
if processGroup != nil { var process *Process
process := processGroup.processes[modelID] if pm.matrix != nil {
if process != nil { process, _ = pm.matrix.GetProcess(modelID)
var stateStr string } else {
switch process.CurrentState() { processGroup := pm.findGroupByModelName(modelID)
case StateReady: if processGroup != nil {
stateStr = "ready" process = processGroup.processes[modelID]
case StateStarting: }
stateStr = "starting" }
case StateStopping: if process != nil {
stateStr = "stopping" switch process.CurrentState() {
case StateShutdown: case StateReady:
stateStr = "shutdown" state = "ready"
case StateStopped: case StateStarting:
stateStr = "stopped" state = "starting"
default: case StateStopping:
stateStr = "unknown" state = "stopping"
} case StateShutdown:
state = stateStr state = "shutdown"
case StateStopped:
state = "stopped"
} }
} }
models = append(models, Model{ models = append(models, Model{
@@ -157,7 +158,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
} }
} }
sendMetrics := func(metrics []TokenMetrics) { sendMetrics := func(metrics []ActivityLogEntry) {
jsonData, err := json.Marshal(metrics) jsonData, err := json.Marshal(metrics)
if err == nil { if err == nil {
select { select {
@@ -204,8 +205,8 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
/** /**
* Send Metrics data * Send Metrics data
*/ */
defer event.On(func(e TokenMetricsEvent) { defer event.On(func(e ActivityLogEvent) {
sendMetrics([]TokenMetrics{e.Metrics}) sendMetrics([]ActivityLogEntry{e.Metrics})
})() })()
/** /**
@@ -254,18 +255,23 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
return return
} }
processGroup := pm.findGroupByModelName(realModelName) var stopErr error
if processGroup == nil { if pm.matrix != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel)) stopErr = pm.matrix.StopProcess(realModelName, StopImmediately)
return } else {
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
return
}
stopErr = processGroup.StopProcess(realModelName, StopImmediately)
} }
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil { if stopErr != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error()))
return return
} else {
c.String(http.StatusOK, "OK")
} }
c.String(http.StatusOK, "OK")
} }
func (pm *ProxyManager) apiGetVersion(c *gin.Context) { func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
@@ -290,5 +296,10 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, capture) jsonBytes, err := json.Marshal(capture)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal capture"})
return
}
c.Data(http.StatusOK, "application/json", jsonBytes)
} }
+380 -339
View File
File diff suppressed because it is too large Load Diff
+101 -108
View File
@@ -37,21 +37,21 @@
} }
}, },
"node_modules/@emnapi/core": { "node_modules/@emnapi/core": {
"version": "1.9.0", "version": "1.9.2",
"resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.9.0.tgz", "resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.9.2.tgz",
"integrity": "sha512-0DQ98G9ZQZOxfUcQn1waV2yS8aWdZ6kJMbYCJB3oUBecjWYO1fqJ+a1DRfPF3O5JEkwqwP1A9QEN/9mYm2Yd0w==", "integrity": "sha512-UC+ZhH3XtczQYfOlu3lNEkdW/p4dsJ1r/bP7H8+rhao3TTTMO1ATq/4DdIi23XuGoFY+Cz0JmCbdVl0hz9jZcA==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"optional": true, "optional": true,
"dependencies": { "dependencies": {
"@emnapi/wasi-threads": "1.2.0", "@emnapi/wasi-threads": "1.2.1",
"tslib": "^2.4.0" "tslib": "^2.4.0"
} }
}, },
"node_modules/@emnapi/runtime": { "node_modules/@emnapi/runtime": {
"version": "1.9.0", "version": "1.9.2",
"resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.9.0.tgz", "resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.9.2.tgz",
"integrity": "sha512-QN75eB0IH2ywSpRpNddCRfQIhmJYBCJ1x5Lb3IscKAL8bMnVAKnRg8dCoXbHzVLLH7P38N2Z3mtulB7W0J0FKw==", "integrity": "sha512-3U4+MIWHImeyu1wnmVygh5WlgfYDtyf0k8AbLhMFxOipihf6nrWC4syIm/SwEeec0mNSafiiNnMJwbza/Is6Lw==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"optional": true, "optional": true,
@@ -60,9 +60,9 @@
} }
}, },
"node_modules/@emnapi/wasi-threads": { "node_modules/@emnapi/wasi-threads": {
"version": "1.2.0", "version": "1.2.1",
"resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.0.tgz", "resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.1.tgz",
"integrity": "sha512-N10dEJNSsUx41Z6pZsXU8FjPjpBEplgH24sfkmITrBED1/U2Esum9F3lfLrMjKHHjmi557zQn7kR9R+XWXu5Rg==", "integrity": "sha512-uTII7OYF+/Mes/MrcIOYp5yOtSMLBWSIoLPpcgwipoiKbli6k322tcoFsxoIIxPDqW01SQGAgko4EzZi2BNv2w==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"optional": true, "optional": true,
@@ -121,36 +121,28 @@
} }
}, },
"node_modules/@napi-rs/wasm-runtime": { "node_modules/@napi-rs/wasm-runtime": {
"version": "1.1.1", "version": "1.1.3",
"resolved": "https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-1.1.1.tgz", "resolved": "https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-1.1.3.tgz",
"integrity": "sha512-p64ah1M1ld8xjWv3qbvFwHiFVWrq1yFvV4f7w+mzaqiR4IlSgkqhcRdHwsGgomwzBH51sRY4NEowLxnaBjcW/A==", "integrity": "sha512-xK9sGVbJWYb08+mTJt3/YV24WxvxpXcXtP6B172paPZ+Ts69Re9dAr7lKwJoeIx8OoeuimEiRZ7umkiUVClmmQ==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"optional": true, "optional": true,
"dependencies": { "dependencies": {
"@emnapi/core": "^1.7.1",
"@emnapi/runtime": "^1.7.1",
"@tybys/wasm-util": "^0.10.1" "@tybys/wasm-util": "^0.10.1"
}, },
"funding": { "funding": {
"type": "github", "type": "github",
"url": "https://github.com/sponsors/Brooooooklyn" "url": "https://github.com/sponsors/Brooooooklyn"
} },
}, "peerDependencies": {
"node_modules/@oxc-project/runtime": { "@emnapi/core": "^1.7.1",
"version": "0.115.0", "@emnapi/runtime": "^1.7.1"
"resolved": "https://registry.npmjs.org/@oxc-project/runtime/-/runtime-0.115.0.tgz",
"integrity": "sha512-Rg8Wlt5dCbXhQnsXPrkOjL1DTSvXLgb2R/KYfnf1/K+R0k6UMLEmbQXPM+kwrWqSmWA2t0B1EtHy2/3zikQpvQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": "^20.19.0 || >=22.12.0"
} }
}, },
"node_modules/@oxc-project/types": { "node_modules/@oxc-project/types": {
"version": "0.115.0", "version": "0.124.0",
"resolved": "https://registry.npmjs.org/@oxc-project/types/-/types-0.115.0.tgz", "resolved": "https://registry.npmjs.org/@oxc-project/types/-/types-0.124.0.tgz",
"integrity": "sha512-4n91DKnebUS4yjUHl2g3/b2T+IUdCfmoZGhmwsovZCDaJSs+QkVAM+0AqqTxHSsHfeiMuueT75cZaZcT/m0pSw==", "integrity": "sha512-VBFWMTBvHxS11Z5Lvlr3IWgrwhMTXV+Md+EQF0Xf60+wAdsGFTBx7X7K/hP4pi8N7dcm1RvcHwDxZ16Qx8keUg==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"funding": { "funding": {
@@ -158,9 +150,9 @@
} }
}, },
"node_modules/@rolldown/binding-android-arm64": { "node_modules/@rolldown/binding-android-arm64": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-android-arm64/-/binding-android-arm64-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-android-arm64/-/binding-android-arm64-1.0.0-rc.15.tgz",
"integrity": "sha512-lcJL0bN5hpgJfSIz/8PIf02irmyL43P+j1pTCfbD1DbLkmGRuFIA4DD3B3ZOvGqG0XiVvRznbKtN0COQVaKUTg==", "integrity": "sha512-YYe6aWruPZDtHNpwu7+qAHEMbQ/yRl6atqb/AhznLTnD3UY99Q1jE7ihLSahNWkF4EqRPVC4SiR4O0UkLK02tA==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -175,9 +167,9 @@
} }
}, },
"node_modules/@rolldown/binding-darwin-arm64": { "node_modules/@rolldown/binding-darwin-arm64": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-darwin-arm64/-/binding-darwin-arm64-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-darwin-arm64/-/binding-darwin-arm64-1.0.0-rc.15.tgz",
"integrity": "sha512-J7Zk3kLYFsLtuH6U+F4pS2sYVzac0qkjcO5QxHS7OS7yZu2LRs+IXo+uvJ/mvpyUljDJ3LROZPoQfgBIpCMhdQ==", "integrity": "sha512-oArR/ig8wNTPYsXL+Mzhs0oxhxfuHRfG7Ikw7jXsw8mYOtk71W0OkF2VEVh699pdmzjPQsTjlD1JIOoHkLP1Fg==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -192,9 +184,9 @@
} }
}, },
"node_modules/@rolldown/binding-darwin-x64": { "node_modules/@rolldown/binding-darwin-x64": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-darwin-x64/-/binding-darwin-x64-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-darwin-x64/-/binding-darwin-x64-1.0.0-rc.15.tgz",
"integrity": "sha512-iwtmmghy8nhfRGeNAIltcNXzD0QMNaaA5U/NyZc1Ia4bxrzFByNMDoppoC+hl7cDiUq5/1CnFthpT9n+UtfFyg==", "integrity": "sha512-YzeVqOqjPYvUbJSWJ4EDL8ahbmsIXQpgL3JVipmN+MX0XnXMeWomLN3Fb+nwCmP/jfyqte5I3XRSm7OfQrbyxw==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -209,9 +201,9 @@
} }
}, },
"node_modules/@rolldown/binding-freebsd-x64": { "node_modules/@rolldown/binding-freebsd-x64": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-freebsd-x64/-/binding-freebsd-x64-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-freebsd-x64/-/binding-freebsd-x64-1.0.0-rc.15.tgz",
"integrity": "sha512-DLFYI78SCiZr5VvdEplsVC2Vx53lnA4/Ga5C65iyldMVaErr86aiqCoNBLl92PXPfDtUYjUh+xFFor40ueNs4Q==", "integrity": "sha512-9Erhx956jeQ0nNTyif1+QWAXDRD38ZNjr//bSHrt6wDwB+QkAfl2q6Mn1k6OBPerznjRmbM10lgRb1Pli4xZPw==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -226,9 +218,9 @@
} }
}, },
"node_modules/@rolldown/binding-linux-arm-gnueabihf": { "node_modules/@rolldown/binding-linux-arm-gnueabihf": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm-gnueabihf/-/binding-linux-arm-gnueabihf-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm-gnueabihf/-/binding-linux-arm-gnueabihf-1.0.0-rc.15.tgz",
"integrity": "sha512-CsjTmTwd0Hri6iTw/DRMK7kOZ7FwAkrO4h8YWKoX/kcj833e4coqo2wzIFywtch/8Eb5enQ/lwLM7w6JX1W5RQ==", "integrity": "sha512-cVwk0w8QbZJGTnP/AHQBs5yNwmpgGYStL88t4UIaqcvYJWBfS0s3oqVLZPwsPU6M0zlW4GqjP0Zq5MnAGwFeGA==",
"cpu": [ "cpu": [
"arm" "arm"
], ],
@@ -243,9 +235,9 @@
} }
}, },
"node_modules/@rolldown/binding-linux-arm64-gnu": { "node_modules/@rolldown/binding-linux-arm64-gnu": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm64-gnu/-/binding-linux-arm64-gnu-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm64-gnu/-/binding-linux-arm64-gnu-1.0.0-rc.15.tgz",
"integrity": "sha512-2x9O2JbSPxpxMDhP9Z74mahAStibTlrBMW0520+epJH5sac7/LwZW5Bmg/E6CXuEF53JJFW509uP+lSedaUNxg==", "integrity": "sha512-eBZ/u8iAK9SoHGanqe/jrPnY0JvBN6iXbVOsbO38mbz+ZJsaobExAm1Iu+rxa4S1l2FjG0qEZn4Rc6X8n+9M+w==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -260,9 +252,9 @@
} }
}, },
"node_modules/@rolldown/binding-linux-arm64-musl": { "node_modules/@rolldown/binding-linux-arm64-musl": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm64-musl/-/binding-linux-arm64-musl-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm64-musl/-/binding-linux-arm64-musl-1.0.0-rc.15.tgz",
"integrity": "sha512-JA1QRW31ogheAIRhIg9tjMfsYbglXXYGNPLdPEYrwFxdbkQCAzvpSCSHCDWNl4hTtrol8WeboCSEpjdZK8qrCg==", "integrity": "sha512-ZvRYMGrAklV9PEkgt4LQM6MjQX2P58HPAuecwYObY2DhS2t35R0I810bKi0wmaYORt6m/2Sm+Z+nFgb0WhXNcQ==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -277,9 +269,9 @@
} }
}, },
"node_modules/@rolldown/binding-linux-ppc64-gnu": { "node_modules/@rolldown/binding-linux-ppc64-gnu": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-linux-ppc64-gnu/-/binding-linux-ppc64-gnu-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-ppc64-gnu/-/binding-linux-ppc64-gnu-1.0.0-rc.15.tgz",
"integrity": "sha512-aOKU9dJheda8Kj8Y3w9gnt9QFOO+qKPAl8SWd7JPHP+Cu0EuDAE5wokQubLzIDQWg2myXq2XhTpOVS07qqvT+w==", "integrity": "sha512-VDpgGBzgfg5hLg+uBpCLoFG5kVvEyafmfxGUV0UHLcL5irxAK7PKNeC2MwClgk6ZAiNhmo9FLhRYgvMmedLtnQ==",
"cpu": [ "cpu": [
"ppc64" "ppc64"
], ],
@@ -294,9 +286,9 @@
} }
}, },
"node_modules/@rolldown/binding-linux-s390x-gnu": { "node_modules/@rolldown/binding-linux-s390x-gnu": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-linux-s390x-gnu/-/binding-linux-s390x-gnu-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-s390x-gnu/-/binding-linux-s390x-gnu-1.0.0-rc.15.tgz",
"integrity": "sha512-OalO94fqj7IWRn3VdXWty75jC5dk4C197AWEuMhIpvVv2lw9fiPhud0+bW2ctCxb3YoBZor71QHbY+9/WToadA==", "integrity": "sha512-y1uXY3qQWCzcPgRJATPSOUP4tCemh4uBdY7e3EZbVwCJTY3gLJWnQABgeUetvED+bt1FQ01OeZwvhLS2bpNrAQ==",
"cpu": [ "cpu": [
"s390x" "s390x"
], ],
@@ -311,9 +303,9 @@
} }
}, },
"node_modules/@rolldown/binding-linux-x64-gnu": { "node_modules/@rolldown/binding-linux-x64-gnu": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-linux-x64-gnu/-/binding-linux-x64-gnu-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-x64-gnu/-/binding-linux-x64-gnu-1.0.0-rc.15.tgz",
"integrity": "sha512-cVEl1vZtBsBZna3YMjGXNvnYYrOJ7RzuWvZU0ffvJUexWkukMaDuGhUXn0rjnV0ptzGVkvc+vW9Yqy6h8YX4pg==", "integrity": "sha512-023bTPBod7J3Y/4fzAN6QtpkSABR0rigtrwaP+qSEabUh5zf6ELr9Nc7GujaROuPY3uwdSIXWrvhn1KxOvurWA==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -328,9 +320,9 @@
} }
}, },
"node_modules/@rolldown/binding-linux-x64-musl": { "node_modules/@rolldown/binding-linux-x64-musl": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-linux-x64-musl/-/binding-linux-x64-musl-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-x64-musl/-/binding-linux-x64-musl-1.0.0-rc.15.tgz",
"integrity": "sha512-UzYnKCIIc4heAKgI4PZ3dfBGUZefGCJ1TPDuLHoCzgrMYPb5Rv6TLFuYtyM4rWyHM7hymNdsg5ik2C+UD9VDbA==", "integrity": "sha512-witB2O0/hU4CgfOOKUoeFgQ4GktPi1eEbAhaLAIpgD6+ZnhcPkUtPsoKKHRzmOoWPZue46IThdSgdo4XneOLYw==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -345,9 +337,9 @@
} }
}, },
"node_modules/@rolldown/binding-openharmony-arm64": { "node_modules/@rolldown/binding-openharmony-arm64": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-openharmony-arm64/-/binding-openharmony-arm64-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-openharmony-arm64/-/binding-openharmony-arm64-1.0.0-rc.15.tgz",
"integrity": "sha512-+6zoiF+RRyf5cdlFQP7nm58mq7+/2PFaY2DNQeD4B87N36JzfF/l9mdBkkmTvSYcYPE8tMh/o3cRlsx1ldLfog==", "integrity": "sha512-UCL68NJ0Ud5zRipXZE9dF5PmirzJE4E4BCIOOssEnM7wLDsxjc6Qb0sGDxTNRTP53I6MZpygyCpY8Aa8sPfKPg==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -362,9 +354,9 @@
} }
}, },
"node_modules/@rolldown/binding-wasm32-wasi": { "node_modules/@rolldown/binding-wasm32-wasi": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-wasm32-wasi/-/binding-wasm32-wasi-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-wasm32-wasi/-/binding-wasm32-wasi-1.0.0-rc.15.tgz",
"integrity": "sha512-rgFN6sA/dyebil3YTlL2evvi/M+ivhfnyxec7AccTpRPccno/rPoNlqybEZQBkcbZu8Hy+eqNJCqfBR8P7Pg8g==", "integrity": "sha512-ApLruZq/ig+nhaE7OJm4lDjayUnOHVUa77zGeqnqZ9pn0ovdVbbNPerVibLXDmWeUZXjIYIT8V3xkT58Rm9u5Q==",
"cpu": [ "cpu": [
"wasm32" "wasm32"
], ],
@@ -372,16 +364,18 @@
"license": "MIT", "license": "MIT",
"optional": true, "optional": true,
"dependencies": { "dependencies": {
"@napi-rs/wasm-runtime": "^1.1.1" "@emnapi/core": "1.9.2",
"@emnapi/runtime": "1.9.2",
"@napi-rs/wasm-runtime": "^1.1.3"
}, },
"engines": { "engines": {
"node": ">=14.0.0" "node": ">=14.0.0"
} }
}, },
"node_modules/@rolldown/binding-win32-arm64-msvc": { "node_modules/@rolldown/binding-win32-arm64-msvc": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-win32-arm64-msvc/-/binding-win32-arm64-msvc-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-win32-arm64-msvc/-/binding-win32-arm64-msvc-1.0.0-rc.15.tgz",
"integrity": "sha512-lHVNUG/8nlF1IQk1C0Ci574qKYyty2goMiPlRqkC5R+3LkXDkL5Dhx8ytbxq35m+pkHVIvIxviD+TWLdfeuadA==", "integrity": "sha512-KmoUoU7HnN+Si5YWJigfTws1jz1bKBYDQKdbLspz0UaqjjFkddHsqorgiW1mxcAj88lYUE6NC/zJNwT+SloqtA==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -396,9 +390,9 @@
} }
}, },
"node_modules/@rolldown/binding-win32-x64-msvc": { "node_modules/@rolldown/binding-win32-x64-msvc": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/binding-win32-x64-msvc/-/binding-win32-x64-msvc-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/binding-win32-x64-msvc/-/binding-win32-x64-msvc-1.0.0-rc.15.tgz",
"integrity": "sha512-G0oA4+w1iY5AGi5HcDTxWsoxF509hrFIPB2rduV5aDqS9FtDg1CAfa7V34qImbjfhIcA8C+RekocJZA96EarwQ==", "integrity": "sha512-3P2A8L+x75qavWLe/Dll3EYBJLQmtkJN8rfh+U/eR3MqMgL/h98PhYI+JFfXuDPgPeCB7iZAKiqii5vqOvnA0g==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -413,9 +407,9 @@
} }
}, },
"node_modules/@rolldown/pluginutils": { "node_modules/@rolldown/pluginutils": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-rc.15.tgz",
"integrity": "sha512-w6oiRWgEBl04QkFZgmW+jnU1EC9b57Oihi2ot3HNWIQRqgHp5PnYDia5iZ5FF7rpa4EQdiqMDXjlqKGXBhsoXw==", "integrity": "sha512-UromN0peaE53IaBRe9W7CjrZgXl90fqGpK+mIZbA3qSTeYqg3pqpROBdIPvOG3F5ereDHNwoHBI2e50n1BDr1g==",
"dev": true, "dev": true,
"license": "MIT" "license": "MIT"
}, },
@@ -2794,9 +2788,9 @@
} }
}, },
"node_modules/postcss": { "node_modules/postcss": {
"version": "8.5.8", "version": "8.5.12",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.8.tgz", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
"integrity": "sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==", "integrity": "sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==",
"dev": true, "dev": true,
"funding": [ "funding": [
{ {
@@ -2972,14 +2966,14 @@
} }
}, },
"node_modules/rolldown": { "node_modules/rolldown": {
"version": "1.0.0-rc.9", "version": "1.0.0-rc.15",
"resolved": "https://registry.npmjs.org/rolldown/-/rolldown-1.0.0-rc.9.tgz", "resolved": "https://registry.npmjs.org/rolldown/-/rolldown-1.0.0-rc.15.tgz",
"integrity": "sha512-9EbgWge7ZH+yqb4d2EnELAntgPTWbfL8ajiTW+SyhJEC4qhBbkCKbqFV4Ge4zmu5ziQuVbWxb/XwLZ+RIO7E8Q==", "integrity": "sha512-Ff31guA5zT6WjnGp0SXw76X6hzGRk/OQq2hE+1lcDe+lJdHSgnSX6nK3erbONHyCbpSj9a9E+uX/OvytZoWp2g==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"@oxc-project/types": "=0.115.0", "@oxc-project/types": "=0.124.0",
"@rolldown/pluginutils": "1.0.0-rc.9" "@rolldown/pluginutils": "1.0.0-rc.15"
}, },
"bin": { "bin": {
"rolldown": "bin/cli.mjs" "rolldown": "bin/cli.mjs"
@@ -2988,21 +2982,21 @@
"node": "^20.19.0 || >=22.12.0" "node": "^20.19.0 || >=22.12.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@rolldown/binding-android-arm64": "1.0.0-rc.9", "@rolldown/binding-android-arm64": "1.0.0-rc.15",
"@rolldown/binding-darwin-arm64": "1.0.0-rc.9", "@rolldown/binding-darwin-arm64": "1.0.0-rc.15",
"@rolldown/binding-darwin-x64": "1.0.0-rc.9", "@rolldown/binding-darwin-x64": "1.0.0-rc.15",
"@rolldown/binding-freebsd-x64": "1.0.0-rc.9", "@rolldown/binding-freebsd-x64": "1.0.0-rc.15",
"@rolldown/binding-linux-arm-gnueabihf": "1.0.0-rc.9", "@rolldown/binding-linux-arm-gnueabihf": "1.0.0-rc.15",
"@rolldown/binding-linux-arm64-gnu": "1.0.0-rc.9", "@rolldown/binding-linux-arm64-gnu": "1.0.0-rc.15",
"@rolldown/binding-linux-arm64-musl": "1.0.0-rc.9", "@rolldown/binding-linux-arm64-musl": "1.0.0-rc.15",
"@rolldown/binding-linux-ppc64-gnu": "1.0.0-rc.9", "@rolldown/binding-linux-ppc64-gnu": "1.0.0-rc.15",
"@rolldown/binding-linux-s390x-gnu": "1.0.0-rc.9", "@rolldown/binding-linux-s390x-gnu": "1.0.0-rc.15",
"@rolldown/binding-linux-x64-gnu": "1.0.0-rc.9", "@rolldown/binding-linux-x64-gnu": "1.0.0-rc.15",
"@rolldown/binding-linux-x64-musl": "1.0.0-rc.9", "@rolldown/binding-linux-x64-musl": "1.0.0-rc.15",
"@rolldown/binding-openharmony-arm64": "1.0.0-rc.9", "@rolldown/binding-openharmony-arm64": "1.0.0-rc.15",
"@rolldown/binding-wasm32-wasi": "1.0.0-rc.9", "@rolldown/binding-wasm32-wasi": "1.0.0-rc.15",
"@rolldown/binding-win32-arm64-msvc": "1.0.0-rc.9", "@rolldown/binding-win32-arm64-msvc": "1.0.0-rc.15",
"@rolldown/binding-win32-x64-msvc": "1.0.0-rc.9" "@rolldown/binding-win32-x64-msvc": "1.0.0-rc.15"
} }
}, },
"node_modules/sade": { "node_modules/sade": {
@@ -3416,17 +3410,16 @@
} }
}, },
"node_modules/vite": { "node_modules/vite": {
"version": "8.0.0", "version": "8.0.8",
"resolved": "https://registry.npmjs.org/vite/-/vite-8.0.0.tgz", "resolved": "https://registry.npmjs.org/vite/-/vite-8.0.8.tgz",
"integrity": "sha512-fPGaRNj9Zytaf8LEiBhY7Z6ijnFKdzU/+mL8EFBaKr7Vw1/FWcTBAMW0wLPJAGMPX38ZPVCVgLceWiEqeoqL2Q==", "integrity": "sha512-dbU7/iLVa8KZALJyLOBOQ88nOXtNG8vxKuOT4I2mD+Ya70KPceF4IAmDsmU0h1Qsn5bPrvsY9HJstCRh3hG6Uw==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"@oxc-project/runtime": "0.115.0",
"lightningcss": "^1.32.0", "lightningcss": "^1.32.0",
"picomatch": "^4.0.3", "picomatch": "^4.0.4",
"postcss": "^8.5.8", "postcss": "^8.5.8",
"rolldown": "1.0.0-rc.9", "rolldown": "1.0.0-rc.15",
"tinyglobby": "^0.2.15" "tinyglobby": "^0.2.15"
}, },
"bin": { "bin": {
@@ -3443,8 +3436,8 @@
}, },
"peerDependencies": { "peerDependencies": {
"@types/node": "^20.19.0 || >=22.12.0", "@types/node": "^20.19.0 || >=22.12.0",
"@vitejs/devtools": "^0.0.0-alpha.31", "@vitejs/devtools": "^0.1.0",
"esbuild": "^0.27.0", "esbuild": "^0.27.0 || ^0.28.0",
"jiti": ">=1.21.0", "jiti": ">=1.21.0",
"less": "^4.0.0", "less": "^4.0.0",
"sass": "^1.70.0", "sass": "^1.70.0",
@@ -0,0 +1,97 @@
<script lang="ts">
import { inFlightRequests, metrics } from "../stores/api";
import { persistentStore } from "../stores/persistent";
import { calculateHistogramData } from "../lib/histogram";
import TokenHistogram from "./TokenHistogram.svelte";
const nf = new Intl.NumberFormat();
const histogramCollapsed = persistentStore<boolean>("activity-histogram-collapsed", false);
let stats = $derived.by(() => {
const totalRequests = $metrics.length;
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.tokens.input_tokens, 0);
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.tokens.output_tokens, 0);
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.tokens.cache_tokens, 0);
const promptPerSecond = $metrics.filter((m) => m.tokens.prompt_per_second > 0).map((m) => m.tokens.prompt_per_second);
const tokensPerSecond = $metrics.filter((m) => m.tokens.tokens_per_second > 0).map((m) => m.tokens.tokens_per_second);
const promptHistogramData =
promptPerSecond.length > 0 ? calculateHistogramData(promptPerSecond) : null;
const genHistogramData =
tokensPerSecond.length > 0 ? calculateHistogramData(tokensPerSecond) : null;
return {
totalRequests,
totalInputTokens,
totalOutputTokens,
totalCacheTokens,
inFlightRequests: $inFlightRequests,
promptHistogramData,
genHistogramData,
};
});
</script>
<div class="card relative p-3">
<button
class="absolute top-2 right-2 w-6 h-6 flex items-center justify-center rounded-full border border-gray-300 dark:border-gray-600 text-gray-400 dark:text-gray-500 hover:text-gray-600 dark:hover:text-gray-300 hover:border-gray-400 dark:hover:border-gray-400 transition-colors"
onclick={() => ($histogramCollapsed = !$histogramCollapsed)}
title={$histogramCollapsed ? "Show histograms" : "Hide histograms"}
>
{#if $histogramCollapsed}
<svg class="w-3.5 h-3.5" viewBox="0 0 16 16" fill="currentColor">
<path d="M4.5 6l3.5 4 3.5-4H4.5z" />
</svg>
{:else}
<svg class="w-3 h-3" viewBox="0 0 16 16" fill="currentColor">
<path d="M3.5 3.5l9 9M12.5 3.5l-9 9" stroke="currentColor" stroke-width="2" stroke-linecap="round" fill="none" />
</svg>
{/if}
</button>
{#if !$histogramCollapsed}
<div class="flex flex-col sm:flex-row gap-6 mb-3">
<div class="w-full sm:w-1/2 min-w-0">
<div class="text-sm font-medium text-gray-500 dark:text-gray-400 mb-1">Prompt Processing</div>
{#if stats.promptHistogramData}
<TokenHistogram
data={stats.promptHistogramData}
unit="prompt tokens/sec"
colorClass="text-amber-500 dark:text-amber-400"
/>
{:else}
<div class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">No prompt speed data yet</div>
{/if}
</div>
<div class="w-full sm:w-1/2 min-w-0">
<div class="text-sm font-medium text-gray-500 dark:text-gray-400 mb-1">Token Generation</div>
{#if stats.genHistogramData}
<TokenHistogram data={stats.genHistogramData} unit="tokens/sec" />
{:else}
<div class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">No generation speed data yet</div>
{/if}
</div>
</div>
{/if}
<div class="grid grid-cols-4 gap-x-6 gap-y-1 text-sm">
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Requests</div>
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Cached</div>
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Processed</div>
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Generated</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalRequests)}</span> completed,
<span class="font-semibold">{nf.format(stats.inFlightRequests)}</span> waiting
</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalCacheTokens)}</span> tokens
</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalInputTokens)}</span> tokens
</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalOutputTokens)}</span> tokens
</div>
</div>
</div>
@@ -106,6 +106,7 @@
const delta = parsed.choices?.[0]?.delta; const delta = parsed.choices?.[0]?.delta;
if (delta?.content) result.content += delta.content; if (delta?.content) result.content += delta.content;
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content; if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
if (delta?.reasoning) result.reasoning += delta.reasoning;
} catch { } catch {
// skip unparseable lines // skip unparseable lines
} }
+7 -1
View File
@@ -50,7 +50,7 @@
<a <a
href="/" href="/"
use:link use:link
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}" class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold underline underline-offset-4' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
> >
Playground Playground
</a> </a>
@@ -59,6 +59,8 @@
use:link use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap" class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/models", $currentRoute)} class:font-semibold={isActive("/models", $currentRoute)}
class:underline={isActive("/models", $currentRoute)}
class:underline-offset-4={isActive("/models", $currentRoute)}
> >
Models Models
</a> </a>
@@ -67,6 +69,8 @@
use:link use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap" class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/activity", $currentRoute)} class:font-semibold={isActive("/activity", $currentRoute)}
class:underline={isActive("/activity", $currentRoute)}
class:underline-offset-4={isActive("/activity", $currentRoute)}
> >
Activity Activity
</a> </a>
@@ -75,6 +79,8 @@
use:link use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap" class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/logs", $currentRoute)} class:font-semibold={isActive("/logs", $currentRoute)}
class:underline={isActive("/logs", $currentRoute)}
class:underline-offset-4={isActive("/logs", $currentRoute)}
> >
Logs Logs
</a> </a>
-167
View File
@@ -1,167 +0,0 @@
<script lang="ts">
import { inFlightRequests, metrics } from "../stores/api";
import TokenHistogram from "./TokenHistogram.svelte";
interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
let stats = $derived.by(() => {
const totalRequests = $metrics.length;
if (totalRequests === 0) {
return {
totalRequests: 0,
totalInputTokens: 0,
totalOutputTokens: 0,
inFlightRequests: $inFlightRequests,
tokenStats: { p99: "0", p95: "0", p50: "0" },
histogramData: null,
};
}
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
// Calculate token statistics using output_tokens and duration_ms
const validMetrics = $metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
if (validMetrics.length === 0) {
return {
totalRequests,
totalInputTokens,
totalOutputTokens,
inFlightRequests: $inFlightRequests,
tokenStats: { p99: "0", p95: "0", p50: "0" },
histogramData: null,
};
}
// Calculate tokens/second for each valid metric
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
// Sort for percentile calculation
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
// Create histogram data
const min = Math.min(...tokensPerSecond);
const max = Math.max(...tokensPerSecond);
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5)));
const binSize = (max - min) / binCount;
const bins = Array(binCount).fill(0);
tokensPerSecond.forEach((value) => {
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
bins[binIndex]++;
});
const histogramData: HistogramData = {
bins,
min,
max,
binSize,
p99,
p95,
p50,
};
return {
totalRequests,
totalInputTokens,
totalOutputTokens,
inFlightRequests: $inFlightRequests,
tokenStats: {
p99: p99.toFixed(2),
p95: p95.toFixed(2),
p50: p50.toFixed(2),
},
histogramData,
};
});
const nf = new Intl.NumberFormat();
</script>
<div class="card">
<div class="rounded-lg overflow-hidden border border-card-border-inner">
<table class="min-w-full divide-y divide-card-border-inner">
<thead class="bg-secondary">
<tr>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">Requests</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Processed
</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Generated
</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Token Stats (tokens/sec)
</th>
</tr>
</thead>
<tbody class="bg-surface divide-y divide-card-border-inner">
<tr class="hover:bg-secondary">
<td class="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">
<div class="flex flex-col gap-1">
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Completed: {nf.format(stats.totalRequests)}</span>
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Waiting: {nf.format(stats.inFlightRequests)}</span>
</div>
</td>
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{nf.format(stats.totalInputTokens)}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{nf.format(stats.totalOutputTokens)}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td class="px-4 py-4 border-l border-gray-200 dark:border-white/10">
<div class="space-y-3">
<div class="grid grid-cols-3 gap-2 items-center">
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P50</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p50}
</div>
</div>
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P95</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p95}
</div>
</div>
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P99</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p99}
</div>
</div>
</div>
{#if stats.histogramData}
<TokenHistogram data={stats.histogramData} />
{/if}
</div>
</td>
</tr>
</tbody>
</table>
</div>
</div>
+41 -25
View File
@@ -1,23 +1,19 @@
<script lang="ts"> <script lang="ts">
interface HistogramData { import type { HistogramData } from "../lib/types";
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
interface Props { let {
data,
unit = "tokens/sec",
colorClass = "text-blue-500 dark:text-blue-400",
}: {
data: HistogramData; data: HistogramData;
} unit?: string;
colorClass?: string;
} = $props();
let { data }: Props = $props(); const height = 250;
const padding = { top: 30, right: 20, bottom: 40, left: 75 };
const height = 120; const viewBoxWidth = 1200;
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
const viewBoxWidth = 600;
const chartWidth = viewBoxWidth - padding.left - padding.right; const chartWidth = viewBoxWidth - padding.left - padding.right;
const chartHeight = height - padding.top - padding.bottom; const chartHeight = height - padding.top - padding.bottom;
@@ -43,6 +39,24 @@
opacity="0.3" opacity="0.3"
/> />
<!-- Y-axis ticks and labels -->
{#each [0, 0.5, 1] as fraction}
{@const tickCount = Math.round(maxCount * fraction)}
{@const tickY = height - padding.bottom - fraction * chartHeight}
<line
x1={padding.left - 8}
y1={tickY}
x2={padding.left}
y2={tickY}
stroke="currentColor"
stroke-width="1"
opacity="0.4"
/>
<text x={padding.left - 10} y={tickY + 10} font-size="26" fill="currentColor" opacity="0.8" text-anchor="end">
{tickCount}
</text>
{/each}
<!-- X-axis --> <!-- X-axis -->
<line <line
x1={padding.left} x1={padding.left}
@@ -69,9 +83,9 @@
height={barHeight} height={barHeight}
fill="currentColor" fill="currentColor"
opacity="0.6" opacity="0.6"
class="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer" class="{colorClass} hover:opacity-90 transition-opacity cursor-pointer"
/> />
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title> <title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} ${unit}\nCount: ${count}`}</title>
</g> </g>
{/each} {/each}
@@ -113,17 +127,19 @@
/> />
<!-- X-axis labels --> <!-- X-axis labels -->
<text x={padding.left} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="start"> <text x={padding.left} y={height - 8} font-size="26" fill="currentColor" opacity="0.8" text-anchor="start">
{data.min.toFixed(1)} {data.min.toFixed(1)}
</text> </text>
<text x={viewBoxWidth - padding.right} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="end"> <text
x={viewBoxWidth - padding.right}
y={height - 8}
font-size="26"
fill="currentColor"
opacity="0.8"
text-anchor="end"
>
{data.max.toFixed(1)} {data.max.toFixed(1)}
</text> </text>
<!-- X-axis label -->
<text x={padding.left + chartWidth / 2} y={height - 2} font-size="10" fill="currentColor" opacity="0.6" text-anchor="middle">
Tokens/Second Distribution
</text>
</svg> </svg>
</div> </div>
+1 -1
View File
@@ -25,7 +25,7 @@ function parseSSELine(line: string): StreamChunk | null {
const parsed = JSON.parse(data); const parsed = JSON.parse(data);
const delta = parsed.choices?.[0]?.delta; const delta = parsed.choices?.[0]?.delta;
const content = delta?.content || ""; const content = delta?.content || "";
const reasoning_content = delta?.reasoning_content || ""; const reasoning_content = delta?.reasoning_content || delta?.reasoning || "";
if (content || reasoning_content) { if (content || reasoning_content) {
return { content, reasoning_content, done: false }; return { content, reasoning_content, done: false };
+167
View File
@@ -0,0 +1,167 @@
import { describe, it, expect } from "vitest";
import { calculateHistogramData } from "./histogram";
describe("calculateHistogramData", () => {
describe("edge cases", () => {
it("returns null for empty input", () => {
expect(calculateHistogramData([])).toBeNull();
});
it("handles single value", () => {
const result = calculateHistogramData([42]);
expect(result).not.toBeNull();
expect(result!.bins).toEqual([1]);
expect(result!.min).toBe(42);
expect(result!.max).toBe(42);
expect(result!.binSize).toBe(0);
expect(result!.p50).toBe(42);
expect(result!.p95).toBe(42);
expect(result!.p99).toBe(42);
});
it("handles all identical values", () => {
const result = calculateHistogramData([10, 10, 10, 10, 10]);
expect(result).not.toBeNull();
expect(result!.bins).toEqual([5]);
expect(result!.min).toBe(10);
expect(result!.max).toBe(10);
expect(result!.binSize).toBe(0);
});
it("handles two distinct values", () => {
const result = calculateHistogramData([10, 20]);
expect(result).not.toBeNull();
expect(result!.min).toBe(10);
expect(result!.max).toBe(20);
expect(result!.p50).toBe(15);
const binSum = result!.bins.reduce((s, b) => s + b, 0);
expect(binSum).toBe(2);
});
});
describe("bin distribution", () => {
it("bins sum to total number of values", () => {
const values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
const binSum = result!.bins.reduce((s, b) => s + b, 0);
expect(binSum).toBe(values.length);
});
it("distributes uniform values across bins", () => {
const values = Array.from({ length: 100 }, (_, i) => i);
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.bins.length).toBe(8);
const binSum = result!.bins.reduce((s, b) => s + b, 0);
expect(binSum).toBe(100);
});
it("places values in correct bins", () => {
const values = [1, 1, 1, 5, 5, 9, 9, 9];
const result = calculateHistogramData(values, { minBins: 3, maxBins: 3 });
expect(result).not.toBeNull();
expect(result!.bins.length).toBe(3);
expect(result!.bins.reduce((s, b) => s + b, 0)).toBe(8);
});
it("handles skewed distribution", () => {
const values = [1, 1, 1, 1, 1, 100];
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
const binSum = result!.bins.reduce((s, b) => s + b, 0);
expect(binSum).toBe(6);
});
});
describe("percentiles", () => {
it("calculates correct p50 for even-length array", () => {
const values = [1, 2, 3, 4];
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p50).toBe(2.5);
});
it("calculates correct p50 for odd-length array", () => {
const values = [1, 2, 3, 4, 5];
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p50).toBe(3);
});
it("calculates p99 with interpolation", () => {
const values = Array.from({ length: 100 }, (_, i) => i + 1);
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p99).toBeCloseTo(99.01);
});
it("calculates p95 with interpolation", () => {
const values = Array.from({ length: 100 }, (_, i) => i + 1);
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p95).toBeCloseTo(95.05);
});
it("percentiles are monotonically increasing", () => {
const values = Array.from({ length: 200 }, () => Math.random() * 100);
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p50).toBeLessThanOrEqual(result!.p95);
expect(result!.p95).toBeLessThanOrEqual(result!.p99);
});
});
describe("bin count adaptation", () => {
it("uses minimum bins for small datasets", () => {
// n=8: sturges=4, clamped up to minBins=5
const values = Array.from({ length: 8 }, (_, i) => i);
const result = calculateHistogramData(values);
expect(result!.bins.length).toBe(5);
});
it("scales bins with dataset size", () => {
// n=100: sturges=8
const values = Array.from({ length: 100 }, (_, i) => i);
const result = calculateHistogramData(values);
expect(result!.bins.length).toBe(8);
});
it("caps bins at maximum", () => {
// n=1000: sturges=11, clamped down to maxBins=10
const values = Array.from({ length: 1000 }, (_, i) => i);
const result = calculateHistogramData(values, { minBins: 5, maxBins: 10 });
expect(result!.bins.length).toBe(10);
});
it("respects custom options", () => {
// n=100: sturges=8, within [minBins=5, maxBins=10]
const values = Array.from({ length: 100 }, (_, i) => i);
const result = calculateHistogramData(values, { minBins: 5, maxBins: 10 });
expect(result!.bins.length).toBe(8);
});
});
describe("min and max", () => {
it("correctly identifies min and max", () => {
const values = [5, 3, 8, 1, 9, 2];
const result = calculateHistogramData(values);
expect(result!.min).toBe(1);
expect(result!.max).toBe(9);
});
it("handles negative values", () => {
const values = [-10, -5, 0, 5, 10];
const result = calculateHistogramData(values);
expect(result!.min).toBe(-10);
expect(result!.max).toBe(10);
});
it("handles floating point values", () => {
const values = [1.5, 2.7, 3.14, 0.5, 4.99];
const result = calculateHistogramData(values);
expect(result!.min).toBe(0.5);
expect(result!.max).toBe(4.99);
});
});
});
+71
View File
@@ -0,0 +1,71 @@
import type { HistogramData } from "./types";
export interface HistogramOptions {
minBins?: number;
maxBins?: number;
}
const DEFAULT_OPTIONS: HistogramOptions = {
minBins: 5,
maxBins: 20,
};
function percentile(sorted: number[], p: number): number {
if (sorted.length === 0) return 0;
if (sorted.length === 1) return sorted[0];
const rank = (p / 100) * (sorted.length - 1);
const lower = Math.floor(rank);
const upper = Math.ceil(rank);
const fraction = rank - lower;
return sorted[lower] + fraction * (sorted[upper] - sorted[lower]);
}
export function calculateHistogramData(
values: number[],
options: HistogramOptions = DEFAULT_OPTIONS,
): HistogramData | null {
if (values.length === 0) return null;
const sorted = [...values].sort((a, b) => a - b);
const min = sorted[0];
const max = sorted[sorted.length - 1];
const p50 = percentile(sorted, 50);
const p95 = percentile(sorted, 95);
const p99 = percentile(sorted, 99);
if (min === max) {
return {
bins: [values.length],
min,
max,
binSize: 0,
p50,
p95,
p99,
};
}
const { minBins = 5, maxBins = 20 } = options;
const sturges = Math.ceil(Math.log2(values.length)) + 1;
const binCount = Math.min(maxBins, Math.max(minBins, sturges));
const binSize = (max - min) / binCount;
const bins = new Array(binCount).fill(0);
for (const value of values) {
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
bins[binIndex]++;
}
return {
bins,
min,
max,
binSize,
p50,
p95,
p99,
};
}
+21 -4
View File
@@ -12,15 +12,22 @@ export interface Model {
aliases?: string[]; aliases?: string[];
} }
export interface Metrics { export interface TokenMetrics {
id: number;
timestamp: string;
model: string;
cache_tokens: number; cache_tokens: number;
input_tokens: number; input_tokens: number;
output_tokens: number; output_tokens: number;
prompt_per_second: number; prompt_per_second: number;
tokens_per_second: number; tokens_per_second: number;
}
export interface ActivityLogEntry {
id: number;
timestamp: string;
model: string;
req_path: string;
resp_content_type: string;
resp_status_code: number;
tokens: TokenMetrics;
duration_ms: number; duration_ms: number;
has_capture: boolean; has_capture: boolean;
} }
@@ -48,6 +55,16 @@ export interface APIEventEnvelope {
data: string; data: string;
} }
export interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
export interface VersionInfo { export interface VersionInfo {
build_date: string; build_date: string;
commit: string; commit: string;
+215 -39
View File
@@ -1,9 +1,89 @@
<script lang="ts"> <script lang="ts">
import { metrics, getCapture } from "../stores/api"; import { metrics, getCapture } from "../stores/api";
import ActivityStats from "../components/ActivityStats.svelte";
import Tooltip from "../components/Tooltip.svelte"; import Tooltip from "../components/Tooltip.svelte";
import CaptureDialog from "../components/CaptureDialog.svelte"; import CaptureDialog from "../components/CaptureDialog.svelte";
import { persistentStore } from "../stores/persistent";
import { onMount } from "svelte";
import type { ReqRespCapture } from "../lib/types"; import type { ReqRespCapture } from "../lib/types";
type ColumnKey =
| "id"
| "time"
| "model"
| "req_path"
| "resp_status_code"
| "resp_content_type"
| "cached"
| "prompt"
| "generated"
| "prompt_speed"
| "gen_speed"
| "duration"
| "capture";
interface ColumnDef {
key: ColumnKey;
label: string;
defaultVisible: boolean;
}
const columns: ColumnDef[] = [
{ key: "id", label: "ID", defaultVisible: true },
{ key: "time", label: "Time", defaultVisible: true },
{ key: "model", label: "Model", defaultVisible: true },
{ key: "req_path", label: "Path", defaultVisible: false },
{ key: "resp_status_code", label: "Status", defaultVisible: false },
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
{ key: "cached", label: "Cached", defaultVisible: true },
{ key: "prompt", label: "Prompt", defaultVisible: true },
{ key: "generated", label: "Generated", defaultVisible: true },
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
{ key: "duration", label: "Duration", defaultVisible: true },
{ key: "capture", label: "Capture", defaultVisible: true },
];
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
const visibleColumns = persistentStore<ColumnKey[]>(
"activity-columns",
defaultVisibleKeys
);
let columnsMenuOpen = $state(false);
let dropdownContainer: HTMLDivElement | null = null;
onMount(() => {
function handleKeydown(e: KeyboardEvent) {
if (e.key === "Escape" && columnsMenuOpen) {
columnsMenuOpen = false;
}
}
function handleClick(e: MouseEvent) {
if (columnsMenuOpen && dropdownContainer && !dropdownContainer.contains(e.target as Node)) {
columnsMenuOpen = false;
}
}
document.addEventListener("keydown", handleKeydown);
document.addEventListener("click", handleClick);
return () => {
document.removeEventListener("keydown", handleKeydown);
document.removeEventListener("click", handleClick);
};
});
function toggleColumn(key: ColumnKey) {
const current = $visibleColumns;
if (current.includes(key)) {
if (current.length > 1) {
visibleColumns.set(current.filter((k) => k !== key));
}
} else {
visibleColumns.set([...current, key]);
}
}
function formatSpeed(speed: number): string { function formatSpeed(speed: number): string {
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s"; return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
} }
@@ -62,64 +142,160 @@
</script> </script>
<div class="p-2"> <div class="p-2">
<h1 class="text-2xl font-bold">Activity</h1> <div class="mt-4 mb-4">
<ActivityStats />
</div>
{#if $metrics.length === 0} <div class="card overflow-auto relative min-h-[30rem]">
<div class="text-center py-8"> <div class="flex justify-end px-4" bind:this={dropdownContainer}>
<p class="text-gray-600">No metrics data available</p> <div class="relative">
<button
class="w-8 h-8 flex items-center justify-center rounded hover:bg-secondary-hover transition-colors"
onclick={() => (columnsMenuOpen = !columnsMenuOpen)}
title="Select columns"
>
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6V4m0 2a2 2 0 100 4m0-4a2 2 0 110 4m-6 8a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4m6 6v10m6-2a2 2 0 100-4m0 4a2 2 0 110-4m0 4v2m0-6V4"></path>
</svg>
</button>
{#if columnsMenuOpen}
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]">
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10">
Columns
</div>
{#each columns as col (col.key)}
<label
class="flex items-center gap-2 px-3 py-1.5 text-sm cursor-pointer hover:bg-secondary-hover transition-colors"
>
<input
type="checkbox"
checked={$visibleColumns.includes(col.key)}
onchange={() => toggleColumn(col.key)}
class="rounded"
/>
{col.label}
</label>
{/each}
</div>
{/if}
</div>
</div> </div>
{:else}
<div class="card overflow-auto"> <table class="min-w-full divide-y">
<table class="min-w-full divide-y"> <thead class="border-gray-200 dark:border-white/10">
<thead class="border-gray-200 dark:border-white/10"> <tr class="text-left text-xs uppercase tracking-wider">
<tr class="text-left text-xs uppercase tracking-wider"> {#if $visibleColumns.includes("id")}
<th class="px-6 py-3">ID</th> <th class="px-6 py-3">ID</th>
{/if}
{#if $visibleColumns.includes("time")}
<th class="px-6 py-3">Time</th> <th class="px-6 py-3">Time</th>
{/if}
{#if $visibleColumns.includes("model")}
<th class="px-6 py-3">Model</th> <th class="px-6 py-3">Model</th>
{/if}
{#if $visibleColumns.includes("req_path")}
<th class="px-6 py-3">Path</th>
{/if}
{#if $visibleColumns.includes("resp_status_code")}
<th class="px-6 py-3">Status</th>
{/if}
{#if $visibleColumns.includes("resp_content_type")}
<th class="px-6 py-3">Content-Type</th>
{/if}
{#if $visibleColumns.includes("cached")}
<th class="px-6 py-3"> <th class="px-6 py-3">
Cached <Tooltip content="prompt tokens from cache" /> Cached <Tooltip content="prompt tokens from cache" />
</th> </th>
{/if}
{#if $visibleColumns.includes("prompt")}
<th class="px-6 py-3"> <th class="px-6 py-3">
Prompt <Tooltip content="new prompt tokens processed" /> Prompt <Tooltip content="new prompt tokens processed" />
</th> </th>
{/if}
{#if $visibleColumns.includes("generated")}
<th class="px-6 py-3">Generated</th> <th class="px-6 py-3">Generated</th>
<th class="px-6 py-3">Prompt Processing</th> {/if}
<th class="px-6 py-3">Generation Speed</th> {#if $visibleColumns.includes("prompt_speed")}
<th class="px-6 py-3">Prompt Speed</th>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<th class="px-6 py-3">Gen Speed</th>
{/if}
{#if $visibleColumns.includes("duration")}
<th class="px-6 py-3">Duration</th> <th class="px-6 py-3">Duration</th>
{/if}
{#if $visibleColumns.includes("capture")}
<th class="px-6 py-3">Capture</th> <th class="px-6 py-3">Capture</th>
{/if}
</tr>
</thead>
<tbody class="divide-y">
{#if sortedMetrics.length === 0}
<tr>
<td colspan={$visibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
No activity recorded
</td>
</tr> </tr>
</thead> {:else}
<tbody class="divide-y">
{#each sortedMetrics as metric (metric.id)} {#each sortedMetrics as metric (metric.id)}
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10"> <tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
<td class="px-4 py-4">{metric.id + 1}</td> {#if $visibleColumns.includes("id")}
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td> <td class="px-4 py-4">{metric.id + 1}</td>
<td class="px-6 py-4">{metric.model}</td> {/if}
<td class="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td> {#if $visibleColumns.includes("time")}
<td class="px-6 py-4">{metric.input_tokens.toLocaleString()}</td> <td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
<td class="px-6 py-4">{metric.output_tokens.toLocaleString()}</td> {/if}
<td class="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td> {#if $visibleColumns.includes("model")}
<td class="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td> <td class="px-6 py-4">{metric.model}</td>
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td> {/if}
<td class="px-6 py-4"> {#if $visibleColumns.includes("req_path")}
{#if metric.has_capture} <td class="px-6 py-4">{metric.req_path || "-"}</td>
<button {/if}
onclick={() => viewCapture(metric.id)} {#if $visibleColumns.includes("resp_status_code")}
disabled={loadingCaptureId === metric.id} <td class="px-6 py-4">{metric.resp_status_code || "-"}</td>
class="btn btn--sm" {/if}
> {#if $visibleColumns.includes("resp_content_type")}
{loadingCaptureId === metric.id ? "..." : "View"} <td class="px-6 py-4">{metric.resp_content_type || "-"}</td>
</button> {/if}
{:else} {#if $visibleColumns.includes("cached")}
<span class="text-txtsecondary">-</span> <td class="px-6 py-4">{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}</td>
{/if} {/if}
</td> {#if $visibleColumns.includes("prompt")}
<td class="px-6 py-4">{metric.tokens.input_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("generated")}
<td class="px-6 py-4">{metric.tokens.output_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("prompt_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.prompt_per_second)}</td>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.tokens_per_second)}</td>
{/if}
{#if $visibleColumns.includes("duration")}
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
{/if}
{#if $visibleColumns.includes("capture")}
<td class="px-6 py-4">
{#if metric.has_capture}
<button
onclick={() => viewCapture(metric.id)}
disabled={loadingCaptureId === metric.id}
class="btn btn--sm"
>
{loadingCaptureId === metric.id ? "..." : "View"}
</button>
{:else}
<span class="text-txtsecondary">-</span>
{/if}
</td>
{/if}
</tr> </tr>
{/each} {/each}
</tbody> {/if}
</table> </tbody>
</div> </table>
{/if} </div>
</div> </div>
<CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} /> <CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} />
+23 -37
View File
@@ -10,49 +10,35 @@
const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels"); const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels");
let direction = $derived<"horizontal" | "vertical">( let direction = $derived<"horizontal" | "vertical">(
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal" $screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal",
); );
function cycleViewMode(): void {
const modes: ViewMode[] = ["panels", "proxy", "upstream"];
const currentIndex = modes.indexOf($viewModeStore);
const nextIndex = (currentIndex + 1) % modes.length;
viewModeStore.set(modes[nextIndex]);
}
function getViewModeIcon(mode: ViewMode): string {
switch (mode) {
case "proxy":
return "P";
case "upstream":
return "U";
case "panels":
return "⊞";
}
}
function getViewModeLabel(mode: ViewMode): string {
switch (mode) {
case "proxy":
return "Proxy";
case "upstream":
return "Upstream";
case "panels":
return "Panels";
}
}
</script> </script>
<div class="flex flex-col h-full w-full gap-2"> <div class="flex flex-col h-full w-full gap-2">
<div class="flex items-center gap-2"> <div class="flex items-center gap-1">
<button <button
onclick={cycleViewMode} onclick={() => viewModeStore.set("panels")}
class="btn flex items-center gap-2 text-sm" class:btn={true}
title="Toggle view mode" class:bg-primary={$viewModeStore === "panels"}
aria-label="Toggle view mode: {getViewModeLabel($viewModeStore)}" class:text-btn-primary-text={$viewModeStore === "panels"}
> >
<span class="font-mono font-bold">{getViewModeIcon($viewModeStore)}</span> Both
<span>{getViewModeLabel($viewModeStore)}</span> </button>
<button
onclick={() => viewModeStore.set("proxy")}
class:btn={true}
class:bg-primary={$viewModeStore === "proxy"}
class:text-btn-primary-text={$viewModeStore === "proxy"}
>
Proxy
</button>
<button
onclick={() => viewModeStore.set("upstream")}
class:btn={true}
class:bg-primary={$viewModeStore === "upstream"}
class:text-btn-primary-text={$viewModeStore === "upstream"}
>
Upstream
</button> </button>
</div> </div>
+1 -9
View File
@@ -2,7 +2,6 @@
import { isNarrow } from "../stores/theme"; import { isNarrow } from "../stores/theme";
import { upstreamLogs } from "../stores/api"; import { upstreamLogs } from "../stores/api";
import ModelsPanel from "../components/ModelsPanel.svelte"; import ModelsPanel from "../components/ModelsPanel.svelte";
import StatsPanel from "../components/StatsPanel.svelte";
import LogPanel from "../components/LogPanel.svelte"; import LogPanel from "../components/LogPanel.svelte";
import ResizablePanels from "../components/ResizablePanels.svelte"; import ResizablePanels from "../components/ResizablePanels.svelte";
@@ -14,13 +13,6 @@
<ModelsPanel /> <ModelsPanel />
{/snippet} {/snippet}
{#snippet rightPanel()} {#snippet rightPanel()}
<div class="flex flex-col h-full space-y-4"> <LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
{#if direction === "horizontal"}
<StatsPanel />
{/if}
<div class="flex-1 min-h-0">
<LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
</div>
</div>
{/snippet} {/snippet}
</ResizablePanels> </ResizablePanels>
+12 -4
View File
@@ -1,5 +1,13 @@
import { writable } from "svelte/store"; import { writable } from "svelte/store";
import type { Model, Metrics, VersionInfo, LogData, APIEventEnvelope, ReqRespCapture, InFlightStats } from "../lib/types"; import type {
Model,
ActivityLogEntry,
VersionInfo,
LogData,
APIEventEnvelope,
ReqRespCapture,
InFlightStats,
} from "../lib/types";
import { connectionState } from "./theme"; import { connectionState } from "./theme";
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
@@ -8,7 +16,7 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
export const models = writable<Model[]>([]); export const models = writable<Model[]>([]);
export const proxyLogs = writable<string>(""); export const proxyLogs = writable<string>("");
export const upstreamLogs = writable<string>(""); export const upstreamLogs = writable<string>("");
export const metrics = writable<Metrics[]>([]); export const metrics = writable<ActivityLogEntry[]>([]);
export const inFlightRequests = writable<number>(0); export const inFlightRequests = writable<number>(0);
export const versionInfo = writable<VersionInfo>({ export const versionInfo = writable<VersionInfo>({
build_date: "unknown", build_date: "unknown",
@@ -62,7 +70,7 @@ export function enableAPIEvents(enabled: boolean): void {
const newModels = JSON.parse(message.data) as Model[]; const newModels = JSON.parse(message.data) as Model[];
// Sort models by name and id // Sort models by name and id
newModels.sort((a, b) => { newModels.sort((a, b) => {
return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric : true} ); return (a.name + a.id).localeCompare(b.name + b.id, undefined, { numeric: true });
}); });
models.set(newModels); models.set(newModels);
break; break;
@@ -82,7 +90,7 @@ export function enableAPIEvents(enabled: boolean): void {
} }
case "metrics": { case "metrics": {
const newMetrics = JSON.parse(message.data) as Metrics[]; const newMetrics = JSON.parse(message.data) as ActivityLogEntry[];
metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]); metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]);
break; break;
} }
+2 -1
View File
@@ -3,7 +3,8 @@ import { persistentStore } from "./persistent";
import type { ScreenWidth } from "../lib/types"; import type { ScreenWidth } from "../lib/types";
// Persistent stores // Persistent stores
export const isDarkMode = persistentStore<boolean>("theme", false); const systemDark = typeof window !== "undefined" && window.matchMedia("(prefers-color-scheme: dark)").matches;
export const isDarkMode = persistentStore<boolean>("theme", systemDark);
export const appTitle = persistentStore<string>("app-title", "llama-swap"); export const appTitle = persistentStore<string>("app-title", "llama-swap");
// Non-persistent stores // Non-persistent stores
+4
View File
@@ -26,6 +26,10 @@ export default defineConfig({
assetsDir: "assets", assetsDir: "assets",
}, },
server: { server: {
// yes very insecure but who's running this thing
// on the public internet for dev?! haha.
host: "0.0.0.0",
allowedHosts: true,
proxy: { proxy: {
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development "/api": "http://localhost:8080", // Proxy API calls to Go backend during development
"/logs": "http://localhost:8080", "/logs": "http://localhost:8080",