Compare commits

...

31 Commits

Author SHA1 Message Date
Benson Wong 29d3d9ba20 perf: add macOS GPU monitoring via mactop and ioreg (#816)
Implement performance monitoring on OSX for Apple Silicon hardware. 

The implementation uses a combination of mactop and ioreg. If mactop is
installed (`brew install mactop`) it is used in a headless cli mode to
stream usage metrics. mactop hooks into unpublished(?) C based APIs in
OSX. Rather than introduce a cgo dependency into llama-swap's build
chain only for darwin I opted to go the external process route.

ioreg, which comes bundled with OSX is used as the fallback. It does not
provide temperature and power usage data but is able to show accurate
GPU and memory utilization.

Updates #771, #814
2026-06-03 21:51:03 -07:00
Benson Wong 9be9a87fa0 internal/process: improve windows shutdown behaviour (#808)
Add Windows specific shutdown code paths so stopping of child processes
is more reliable:

- stopping llama-swap won't leave behind any child processes it created
- uses Job Objects in Windows so the whole llama-swap tree is closed by
the os
- add procCtx to baseRouter. It replaces shutdownCtx as a signal for
managing lifetime state.
- shutdownCtx is only used by the router to stop handling new requests
during shutdown
- improve debug logging to make it easier to trace source of issues

Fixes #804
Updates #807
2026-06-01 00:45:30 -07:00
Benson Wong 6ea551362e process,router: make model shutdown and load-streaming robust
Note: The original proxy/process_unix.go had a noop for setProcAttributes
so it also did not stop grandchildren processes. This patch adds that capability 
and improves reliability.

--

Stop() no longer hangs on a shell wrapper that forks the real binary.
The upstream is built with exec.CommandContext + cmd.Cancel +
cmd.WaitDelay, so cmd.Wait() returns even when a forked grandchild
inherits the stdout/stderr pipes. killProcess sends the stop signal
directly (not by cancelling the context) so cmd.WaitDelay measures from
process exit and never silently caps the caller's graceful timeout.

The upstream is also started in its own process group (Setpgid) on Unix,
so the graceful SIGTERM — and the SIGKILL escalation after the timeout —
are delivered to the whole group via the negative PID. A forked
grandchild is reaped with its parent instead of leaking as an orphan.

The loading-spinner SSE goroutine can no longer panic when it outlives
the request. net/http recycles the response writer via Reset(nil) once
ServeHTTP returns; the orphaned goroutine then flushed against a
nil-backed writer and crashed with a SIGSEGV. A release() fence on
loadingWriter lets any in-flight write finish then short-circuits later
writes/flushes, and all three ServeHTTP select branches run a
finishLoading helper (cancelLoad, waitForCompletion, release) before the
writer is reclaimed.

- internal/process: exec.CommandContext + WaitDelay, Setpgid process
groups, group-wide SIGTERM/SIGKILL teardown
- internal/router: release() fence + finishLoading on loadingWriter

fixes #804
2026-05-31 10:11:12 -07:00
Benson Wong 03d58e53fa Add load testing tool to the UI (#805)
Wouldn't it be nice to test the performance, swapping and concurrency
from the UI? Now we can! This is a port of `cmd/test-concurrency` into the UI

Here's a demo of it working with a swap matrix: 

https://github.com/user-attachments/assets/b6bb12ec-0381-46f1-a6b8-27d1c3c0ddb3
2026-05-30 17:04:30 -07:00
Luiszzzor c790d0ee03 fix: update the concurrency middleware to respond with a JSON payload (#798)
update the concurrency middleware to respond with a JSON payload instead
of plain text when the request limit is reached to be compatible with
openai api standard

---------

Co-authored-by: Ludwik <l.czarnota@samsung.com>
2026-05-29 23:59:32 -07:00
Benson Wong 4ca9c478a2 Makefile,internal/server: various release tweaks 2026-05-29 15:27:08 -07:00
Benson Wong 146a9eab24 ui-svelte: update build directory (#801)
Fixes #799
2026-05-29 14:45:05 -07:00
Benson Wong 02e015fa49 Introduce new routing backend (#790)
This is a huge backend change that essentially started with rewriting
the concurrency handling for processes and blew up to a refactor of the
entire application. In short these are the improvements:

**Better state and life cycle management:** 

Life cycle management of processes has always been the trickiest part of
the code. Juggling mutex locks between multiple locations to reduce race
conditions was complex. Too complex for my feeble brain to build a
simple mental model around as llama-swap gained more features. All of
that has been refactored. Most of the locks are gone, replaced with a
single run() that owns all state changes. There is one place to start
from now to understand and extend routing logic.

The improved life cycle management makes it easier to implement more
complex swap optimization strategies in the future like #727.

**Collation of requests:**

llama-swap previously handled requests and swapping in the order they
came in. For example requests for models in this order ABCABC would
result in 5 swaps. Now those requests are handled in this order AABBCC.
The result is less time waiting for swap under a high churn request
queue. This fixes #588 #612.

A possible future enhancement is to support a starvation parameter so
swap can be forced when models have been waiting too long.

**Shared base implementation for groups and swap matrix:** 

During the refactor it became clear that much of the swapping logic was
shared between these two implementations. That is not surprising
considering the swap matrix was added many moons after groups. Now they
share a common base and their specific swap strategies are implemented
into the swapPlanner interface.

Requests for bespoke or specific swapping scenarios is a common theme in
the issues. Now users can implement whatever bespoke and weird swapping
strategy they want in their own fork. Just ask your agent of choice to
implement swapPlanner. I'll still remaining more conservative on what
actually lands in core llama-swap and will continue to evaluate PRs if
the changes is good for everyone or just one specific use case.

**AI / Agentic Disclosure:** 

I paid very close attention to the low level swap concurrency design and
implementation. It's important to keep that essential part reliable,
boring and no surprises. Backwards compatibility was also maintained,
even the one way non-exclusive group model loading behaviour that people
have rightly pointed out be a weird design decision.

With the underlying swap core done the web server, api and UI sitting on
top were largely ported over with Claude Code and Opus 4.7 in multiple
phases. If you're curious I kept the changes in docs/newrouter-todo.md.
I did several passes to make sure things weren't left behind.

However, even frontier LLMs at the time of this PR still make small
decisions that don't make a lot of sense. They get shit wrong all the
time, just in small subtle way.

That said, there's likely to be some new bugs introduced with this
massive refactor. I'm fairly confident that there's no major
architectural flaws that would cause goal seeking agents to make dumb,
ugly code decisions.

For a little while the legacy llama-swap will be available under
cmd/legacy/llama-swap. The plan is to eventually delete that entry point
as well as the proxy package.

On a bit of a personal note, this PR is exciting and a bit sad for me. I
hand wrote much of the original code and this PR ultimately replaces
much of it. While the old code served as a good reference for the agent
to implement the new stuff it still a bit sad to eventually delete it
all.
2026-05-28 21:47:01 -07:00
Cr4xy 63bc266395 Add new power draw column header for rocm-smi monitoring (#788)
# Overview
This patch fixes
https://github.com/mostlygeek/llama-swap/pull/775#issuecomment-4535303706
and removes some unnecessary `break` statements.

## The third variant now also works with power draw:
`
device,Device Name,Device ID,Device Rev,Subsystem ID,GUID,Temperature
(Sensor edge) (C),Temperature (Sensor junction) (C),Temperature (Sensor
memory) (C),Average Graphics Package Power (W),GPU use (%),GPU Memory
Allocated (VRAM%),GPU Memory Read/Write Activity (%),Memory
Activity,Avg. Memory Bandwidth,VRAM Total Memory (B),VRAM Total Used
Memory (B),Card Series,Card Model,Card Vendor,Card SKU,Node ID,GFX
Version
`
<img width="1121" height="315" alt="image"
src="https://github.com/user-attachments/assets/4b908c4d-2401-4dfe-9bac-e7aa770cfb42"
/>

## Old variants:
`
device,Device Name,Device ID,Device Rev,Subsystem ID,GUID,Temperature
(Sensor edge) (C),Temperature (Sensor junction) (C),Temperature (Sensor
memory) (C),Fan speed (level),Fan speed (%),Fan RPM,Current Socket
Graphics Package Power (W),GPU use (%),GPU Memory Allocated (VRAM%),GPU
Memory Read/Write Activity (%),Memory Activity,VRAM Total Memory
(B),VRAM Total Used Memory (B),Card Series,Card Model,Card Vendor,Card
SKU,Node ID,GFX Version
`
<img width="1118" height="308" alt="image"
src="https://github.com/user-attachments/assets/b236e0cd-4505-42e5-b497-cff62c720e3d"
/>

`
device,Device Name,Device ID,Device Rev,Subsystem ID,GUID,Temperature
(Sensor edge) (C),Current Socket Graphics Package Power (W),GPU use
(%),GPU Memory Allocated (VRAM%),Memory Activity,VRAM Total Memory
(B),VRAM Total Used Memory (B),Card Series,Card Model,Card Vendor,Card
SKU,Node ID,GFX Version
`
<img width="1120" height="312" alt="image"
src="https://github.com/user-attachments/assets/1adde1c3-5f35-4db4-ba13-65751ac076e8"
/>
2026-05-25 11:36:16 -07:00
Cr4xy 636b53e70f Improve rocm-smi performance monitoring (#775)
Fix hardcoded indices for rocm-smi.
2026-05-20 17:59:49 -07:00
gatkisson 59cd3b690d Added Windows performance monitoring using nvidia-smi (#773)
updates: #596, #771
2026-05-18 11:02:03 -07:00
Benson Wong 5d1e62d224 Disable auto review feature in coderabbit config 2026-05-18 10:40:21 -07:00
Benson Wong dbb869d019 Increase inactivity thresholds for stale issues
Updated stale issue and close messages to reflect new inactivity thresholds.
2026-05-17 22:52:58 -07:00
Benson Wong 26bb17e57e config.example.yaml: Improve matrix vs groups info
For some use cases groups are simpler to use. Note this in the
documentation that it is still fully supported.
2026-05-17 15:59:25 -07:00
Benson Wong 2982dd3d40 ui-svelte: update link to performance discussion thread 2026-05-17 11:45:56 -07:00
knguyen298 79dc87f881 Add ROCm stats via rocm-smi (#767)
Add ROCm GPU stats support using `rocm-smi`.
2026-05-17 07:58:26 -07:00
krzychdre b2fcc2daa1 ui-svelte: fix cached tokens total counting -1 sentinel (#760)
The backend uses cache_tokens=-1 as a sentinel for endpoints that don't
report cache stats (embeddings, vLLM). The activity table correctly
renders these as "-", but the totals widget summed the sentinels
directly, so each such request subtracted 1 from the displayed total.

- clamp cache_tokens with Math.max(0, ...) when reducing
2026-05-15 14:42:44 -07:00
cdwaage 6a9c4efc8f fix: use --loop instead of -loop for nvidia-smi (driver 540+ compat) (#759) 2026-05-15 13:20:29 -07:00
Benson Wong 0c813e44d1 ui-svelte: package updates 2026-05-14 21:56:04 -07:00
Benson Wong fe71e8a6ea proxy,ui-svelte: improve support for v1/messages and v1/responses (#758)
This improves the support for activity logging from the v1/responses and
v1/messages endpoints.

- add chat endpoint selection to Playground > Chat > Settings
- improve metrics extraction for streaming v1/messages and v1/responses
endpoints (tested with llama-server)

Fixes #742
2026-05-14 21:53:57 -07:00
Benson Wong aac7b8745a ci: set go-version-file in release workflow 2026-05-13 22:12:02 -07:00
Benson Wong 4e606feff0 ci: fix workflow bugs in release and go-ci
- release.yml: merge orphaned `uses:` into the Checkout step
- go-ci.yml: skip simple-responder build when restored from cache
2026-05-13 21:48:27 -07:00
Benson Wong a4b91e08cf Changes and fixes before the release (docs/small tweaks) (#750)
- update README.md with new docker instructions
- update docs/configuration.md
- update .github/workflows to have pinned action versions
- gofmt events package
- fix small bugs in CI scripts
- reduce config options for internal/perf/monitor and config. A ring buffer is used to keep 1hr of entries at max 5s granularity. For long term stats use prometheus monitoring on /metrics

Fixes #744
2026-05-13 21:18:19 -07:00
David Soušek 3e3646f9f9 perf: ignore LACT devices reporting zero VRAM (#753)
Ignore LACT devices that report zero total VRAM.

Some virtual GPUs on headless VMs report `MemTotalMB == 0` through LACT,
which makes them appear in performance monitoring despite not providing
useful memory data. Skip those entries so only usable GPU devices are
reported.

This makes performance monitoring cleaner on headless VMs with virtual
GPUs that report zero VRAM.

Co-authored-by: David Soušek <david.sousek@intelogy.co.uk>
2026-05-13 10:03:54 -07:00
rhtenhove a01afe261b ci: use manifest-aware cleanup action for multi-arch :cpu (#751)
actions/delete-package-versions can't see OCI manifest lists. When the
cpu build pushes a multi-arch image, the registry gets a tagged index
plus one untagged per-platform manifest per arch. The cleanup step with
`delete-only-untagged-versions: true` then deletes the per-platform
children, leaving the index dangling — `docker pull
ghcr.io/mostlygeek/llama-swap:cpu` 404s on the referenced sha.

Swap to dataaxiom/ghcr-cleanup-action, which inspects tagged manifest
lists first and excludes their children from deletion. Single-arch
backends behave the same as before.

Fix #746
2026-05-12 18:04:46 -07:00
rhtenhove 174e8562aa Multi arch cpu (#746)
Encountered a similar problem as in
https://github.com/mostlygeek/llama-swap/issues/709 but in my case I
only needed the :cpu version.

So decided to add the github action to build arm64 combined with the
amd64 version on the same :cpu tag. Already tested it from this fork:
ghcr.io/rhtenhove/llama-swap:cpu and it works perfectly fine.

Adding GPU support is a whole other beast, needing quite a bit more work
and isn't something I can test.
2026-05-11 21:03:48 -07:00
Abdulazez A. 085b54bc88 proxy: fix data race in /running endpoint and typo in error message (#748)
## Problem

The `/running` endpoint in `listRunningProcessesHandler` reads
`process.state` directly without holding `stateMutex`. Meanwhile,
`swapState()` writes to `process.state` while holding the write lock.
This is a data race flagged by the Go race detector.

Also fixes a minor typo: "processes was in state" → "process was in
state".

## Fix

- `proxymanager.go`: Replace `process.state` with
`process.CurrentState()` which acquires `stateMutex.RLock()` before
reading.
- `process.go`: Fix typo in error message.

## Verification

- `gofmt -l` — clean
- `go test -run "TestProcessGroup_|TestProxyManager_" ./proxy/` — all
pass
- `go test ./proxy/config/... ./proxy/cache/...
./proxy/configwatcher/...` — all pass
2026-05-11 12:49:18 -07:00
bankjaneo 2be3416baa ui: add auto theme switch mode based on system theme (#741)
Add system theme detection with automatic switching when OS theme
changes.

- Add ThemeMode type with "light", "dark", and "system" options
- Add system theme listener using matchMedia API
- Update theme toggle to cycle through System → Light → Dark
- Add combined sun/moon icon for system theme mode
- Migrate existing theme preferences to new format
2026-05-09 20:22:18 -07:00
Benson Wong 7e3e94a08a proxy,ui: add performance monitoring with Prometheus metrics (#743)
Add a comprehensive performance monitoring system that collects CPU, memory, swap, load average, network IO, and GPU stats. Provides both a REST API for the UI and a Prometheus /metrics endpoint.

Backend changes:
- New internal/perf package with configurable interval-based stats collection
- GPU monitoring via LACT (Unix socket) and nvidia-smi fallback on Linux
- Ring buffer (internal/ring) for time-series stat storage
- Prometheus /metrics endpoint with all system and GPU metrics
- Moved LogMonitor to internal/logmon package
- New PerformanceConfig for hot-reloadable monitoring settings
- REST /api/performance endpoint replacing SSE streaming

UI changes:
- New Performance page with real-time charts for CPU, memory, GPU, and network
- Reusable PerformanceChart component
- LLAMA_SWAP_URL environment variable support
- Improved capture dialog display

Other:
- Example Grafana dashboard for Prometheus metrics
- monitor-test standalone binary
- Config schema and example updates

fixes #596
2026-05-09 13:29:22 -07:00
Wim Vander Schelden e261745c66 proxy: add versionless API endpoint (#733)
Add versionless endpoints under v/ to support upstream peers that 
do not use the v1/ prefix.

Fixes #728.
2026-05-03 13:47:38 -07:00
Benson Wong 11b7913287 llama-swap.go: remove debounce, replace fmt.Printlns (#731)
small fixes to clean up the main(): 

- remove the debounced config reload 
- replace fmt.Println with a proxy.LogMonitor for consistency
2026-05-02 16:28:53 -07:00
155 changed files with 19447 additions and 1137 deletions
+1 -1
View File
@@ -13,7 +13,7 @@ reviews:
docstrings:
enabled: false
auto_review:
enabled: true
enabled: false
drafts: false
chat:
auto_reply: true
+5 -5
View File
@@ -11,13 +11,13 @@ jobs:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v9
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f #v10.2.0
with:
days-before-issue-stale: 14
days-before-issue-close: 14
days-before-issue-stale: 30
days-before-issue-close: 30
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
stale-issue-message: "This issue is stale because it has been open without activity for 30 days. Please remove the stale label if this was an error."
close-issue-message: "This issue was closed because it has been inactive for 30 days since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}
+2 -2
View File
@@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Validate JSON Schema
run: |
@@ -45,7 +45,7 @@ jobs:
echo "✓ config-schema.json is valid"
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 #v6.2.0
with:
python-version: "3.x"
+32 -8
View File
@@ -9,6 +9,11 @@ on:
# Allows manual triggering of the workflow
workflow_dispatch:
inputs:
dryrun:
description: "Run cleanup step in dry-run mode (log what would be deleted, delete nothing)"
type: boolean
default: false
# Run on workflow file changes (without pushing)
push:
@@ -33,7 +38,7 @@ jobs:
fail-fast: false
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Free up disk space
if: matrix.platform == 'rocm'
@@ -48,8 +53,18 @@ jobs:
echo "After cleanup:"
df -h
# QEMU enables arm64 cross-builds on the amd64 GitHub runner.
# Currently only the cpu backend goes multi-arch; the action is a
# no-op for amd64-only builds, so leaving it on for every matrix
# entry keeps the workflow simple.
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 #v4.1.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -60,14 +75,23 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
# see: https://github.com/actions/delete-package-versions/issues/74
# actions/delete-package-versions can't see manifest lists: pushing
# a multi-arch image with `docker buildx --push` creates a tagged OCI
# index plus one untagged per-platform manifest per arch, and
# `delete-only-untagged-versions: true` then nukes the per-platform
# children, leaving the index dangling — `docker pull :cpu` 404s on
# the referenced digest. dataaxiom/ghcr-cleanup-action walks tagged
# manifest lists and excludes their children from deletion.
delete-untagged-containers:
needs: build-and-push
# Skip on forks — the delete API requires package-admin on the
# upstream account and would otherwise red-x every fork CI run.
if: github.repository == 'mostlygeek/llama-swap'
runs-on: ubuntu-latest
steps:
- uses: actions/delete-package-versions@v5
- uses: dataaxiom/ghcr-cleanup-action@cd0cdb900b5dbf3a6f2cc869f0dbb0b8211f50c4 # v1.0.16
with:
package-name: 'llama-swap'
package-type: 'container'
delete-only-untagged-versions: 'true'
token: ${{ secrets.GITHUB_TOKEN }}
package: llama-swap
delete-untagged: true
dry-run: ${{ inputs.dryrun || false }}
+6 -6
View File
@@ -31,17 +31,17 @@ jobs:
run-tests:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
with:
go-version: '1.23'
go-version-file: go.mod
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae #v5.0.5
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
@@ -56,11 +56,11 @@ jobs:
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae #v5.0.5
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
- name: Test all
shell: bash
run: make test-all
run: make test-all
+7 -6
View File
@@ -30,37 +30,38 @@ jobs:
run-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
with:
go-version-file: go.mod
# Only run in this linux based runner
- name: Check Formatting
run: |
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
gofmt -l . | grep -v 'event/.*_test.go'
if [ "$(gofmt -l . | wc -l)" -gt 0 ]; then
gofmt -l .
exit 1
fi
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae #v5.0.5
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
run: make simple-responder
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae #v5.0.5
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
+9 -9
View File
@@ -20,24 +20,24 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
with:
fetch-depth: 0
ref: ${{ github.event.inputs.tag || github.ref }}
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
with:
go-version-file: go.mod
- name: Set up Node.js
uses: actions/setup-node@v4
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # 6.4.0
with:
node-version: "24"
- name: Install dependencies and build UI
- name: Build UI
run: |
cd ui-svelte
npm ci
npm run build
make ui
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
uses: goreleaser/goreleaser-action@1a80836c5c9d9e5755a25cb59ec6f45a3b5f41a8 #7.2.1
with:
# either 'goreleaser' (default) or 'goreleaser-pro'
distribution: goreleaser
@@ -61,7 +61,7 @@ jobs:
fi
- name: "Trigger tap repository update"
uses: peter-evans/repository-dispatch@v2
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 #4.0.1
with:
token: ${{ secrets.TAP_REPO_PAT }}
repository: mostlygeek/homebrew-llama-swap
+2 -2
View File
@@ -20,10 +20,10 @@ jobs:
run-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Set up Node.js
uses: actions/setup-node@v4
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # 6.4.0
with:
node-version: '24'
cache: 'npm'
+3 -3
View File
@@ -75,7 +75,7 @@ jobs:
backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Free up disk space
run: |
@@ -94,11 +94,11 @@ jobs:
# llama-swap-builder (which has ccache warm) to avoid exhausting disk.
- name: Set up Docker Buildx
if: ${{ !env.ACT }}
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
- name: Log in to GitHub Container Registry
if: ${{ !env.ACT }}
uses: docker/login-action@v3
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 #v4.1.0
with:
registry: ghcr.io
username: ${{ github.actor }}
+3
View File
@@ -5,3 +5,6 @@ dist/
.vscode
.DS_Store
.dev/
# UI build output; placeholder.txt is kept so the go:embed succeeds.
internal/server/ui_dist/*
+2 -1
View File
@@ -21,7 +21,8 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
- Run `gofmt -w <file>` before committing to fix any formatting
- Build go binaries into the ./build/ subdirectory
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests.
- Use `make test-ui` after making changes to the UI in ui-svelte/
+5 -4
View File
@@ -25,15 +25,15 @@ proxy/ui_dist/placeholder.txt:
# use cached test results while developing
test-dev: proxy/ui_dist/placeholder.txt
go test -short ./proxy/...
staticcheck ./proxy/... || true
go test -short ./proxy/... ./internal/...
staticcheck ./proxy/... ./internal/... || true
test: proxy/ui_dist/placeholder.txt
go test -short -count=1 ./proxy/...
go test -short -count=1 ./proxy/... ./internal/...
# for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt
go test -race -count=1 ./proxy/...
go test -race -count=1 ./proxy/... ./internal/...
ui/node_modules:
cd ui-svelte && npm install
@@ -41,6 +41,7 @@ ui/node_modules:
# build react UI
ui: ui/node_modules
cd ui-svelte && npm run build
touch internal/server/ui_dist/placeholder.txt
# Build OSX binary
mac: ui
+22 -12
View File
@@ -52,12 +52,14 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- `GET /logs/stream/upstream` streams upstream process logs only.
- `GET /logs/stream/{model_id}` streams logs for one model (including IDs with slashes, like `author/model`).
- `/health` - just returns "OK"
- `/metrics` - system and GPU metrics for prometheus
- ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable
- Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643))
- Automatic unloading of models after timeout by setting a `ttl`
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
- Apply filters to requests to control inference with `stripParams`, `setParams` and `setParamsByID`
### Web UI
@@ -93,8 +95,24 @@ llama-swap can be installed in multiple ways
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
Two types of container images are built nightly for llama-swap:
1. A unified container with llama-server, ik-llama-server, stable-diffusion.cpp, whisper.cpp and llama-swap built from source. This is only available for cuda and vulkan but has more capabilities. This one is recommended for use.
2. A legacy image that is based on llama.cpp's images and llama-swap copied into the container. Use this one if you prefer to stay close to llama.cpp's container images.
#### Unified container (Recommended)
```shell
$ docker pull ghcr.io/mostlygeek/llama-swap:unified-cuda
# run with a custom configuration and models directory
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/etc/llama-swap/config/config.yaml \
ghcr.io/mostlygeek/llama-swap:unified-cuda
```
#### Legacy container
```shell
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
@@ -104,14 +122,6 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
# configuration hot reload supported with a
# directory volume mount
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
-v /path/to/config:/config \
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
```
<details>
@@ -267,6 +277,6 @@ For Python based inference servers like vllm or tabbyAPI it is recommended to ru
## Star History
> [!NOTE]
> ⭐️ Star this project to help others discover it!
> Thank you to everyone who has given this project a ⭐️!
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
+306
View File
@@ -0,0 +1,306 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/tidwall/gjson"
)
var loremWords = strings.Fields(
"Lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor " +
"incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud " +
"exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat Duis aute " +
"irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla " +
"pariatur Excepteur sint occaecat cupidatat non proident sunt in culpa qui officia " +
"deserunt mollit anim id est laborum Sed ut perspiciatis unde omnis iste natus error " +
"sit voluptatem accusantium doloremque laudantium totam rem aperiam eaque ipsa quae " +
"ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo " +
"Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit",
)
var (
flagListen = flag.String("listen", "localhost:9898", "listen address")
flagTokens = flag.Int("tokens", 1000, "number of tokens to return")
flagTPS = flag.Float64("tps", 75, "tokens per second")
flagLoad = flag.String("load", "0s", "simulated load duration (e.g. 2s, 500ms)")
)
type chunkDelta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
}
type chunkChoice struct {
Index int `json:"index"`
Delta chunkDelta `json:"delta"`
FinishReason *string `json:"finish_reason"`
}
type chatChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []chunkChoice `json:"choices"`
}
type completionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type completionChoice struct {
Index int `json:"index"`
Message completionMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type completionUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type chatCompletion struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []completionChoice `json:"choices"`
Usage completionUsage `json:"usage"`
}
func loremText(n int) string {
words := make([]string, n)
for i := range words {
words[i] = loremWords[i%len(loremWords)]
}
return strings.Join(words, " ")
}
func sendChunk(w http.ResponseWriter, content string, finishReason *string) error {
chunk := chatChunk{
ID: "chatcmpl-fake",
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: "fake-model",
Choices: []chunkChoice{
{
Index: 0,
Delta: chunkDelta{Content: content},
FinishReason: finishReason,
},
},
}
data, err := json.Marshal(chunk)
if err != nil {
return err
}
_, err = fmt.Fprintf(w, "data: %s\n\n", data)
return err
}
// startLoading runs the countdown log and closes ready when loadDur elapses.
// If loadDur is zero, ready is closed immediately.
func startLoading(loadDur time.Duration) <-chan struct{} {
ready := make(chan struct{})
if loadDur == 0 {
close(ready)
return ready
}
go func() {
deadline := time.Now().Add(loadDur)
log.Printf("loading... %s remaining", loadDur.Round(time.Second))
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
timer := time.NewTimer(loadDur)
for {
select {
case <-timer.C:
close(ready)
log.Printf("ready")
return
case <-ticker.C:
if rem := time.Until(deadline).Round(time.Second); rem > 0 {
log.Printf("loading... %s remaining", rem)
}
}
}
}()
return ready
}
func healthHandler(ready <-chan struct{}) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
select {
case <-ready:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusServiceUnavailable)
}
}
}
func chatHandler(ready <-chan struct{}) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "failed to read body", http.StatusBadRequest)
return
}
streaming := gjson.GetBytes(body, "stream").Bool()
ctx := r.Context()
select {
case <-ready:
case <-ctx.Done():
return
}
tokens := *flagTokens
tps := *flagTPS
if tps <= 0 {
tps = 1
}
if !streaming {
delay := time.Duration(float64(tokens) / tps * float64(time.Second))
select {
case <-time.After(delay):
case <-ctx.Done():
return
}
text := loremText(tokens)
resp := chatCompletion{
ID: "chatcmpl-fake",
Object: "chat.completion",
Created: time.Now().Unix(),
Model: "fake-model",
Choices: []completionChoice{
{
Index: 0,
Message: completionMessage{Role: "assistant", Content: text},
FinishReason: "stop",
},
},
Usage: completionUsage{
PromptTokens: 0,
CompletionTokens: tokens,
TotalTokens: tokens,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
// Send role delta first
first := chatChunk{
ID: "chatcmpl-fake",
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: "fake-model",
Choices: []chunkChoice{
{Index: 0, Delta: chunkDelta{Role: "assistant"}},
},
}
if data, err := json.Marshal(first); err == nil {
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
interval := time.Duration(float64(time.Second) / tps)
ticker := time.NewTicker(interval)
defer ticker.Stop()
stop := "stop"
for i := 0; i < tokens; i++ {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
word := loremWords[i%len(loremWords)]
if i < tokens-1 {
if err := sendChunk(w, word+" ", nil); err != nil {
return
}
} else {
if err := sendChunk(w, word, &stop); err != nil {
return
}
}
flusher.Flush()
}
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
}
func main() {
flag.Parse()
loadDur, err := time.ParseDuration(*flagLoad)
if err != nil {
log.Fatalf("invalid -load value %q: %v", *flagLoad, err)
}
ready := startLoading(loadDur)
mux := http.NewServeMux()
mux.HandleFunc("/health", healthHandler(ready))
mux.HandleFunc("/v1/chat/completions", chatHandler(ready))
srv := &http.Server{
Addr: *flagListen,
Handler: mux,
}
go func() {
log.Printf("listening on %s (tokens=%d tps=%.1f load=%s)",
*flagListen, *flagTokens, *flagTPS, loadDur)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("server error: %v", err)
}
}()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("shutting down...")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Printf("shutdown error: %v", err)
}
}
+249
View File
@@ -0,0 +1,249 @@
package main
import (
"context"
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/perf"
"github.com/mostlygeek/llama-swap/internal/watcher"
"github.com/mostlygeek/llama-swap/proxy"
)
var (
version string = "0"
commit string = "abcd1234"
date string = "unknown"
)
func main() {
// Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", "", "listen ip/port")
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
keyFile := flag.String("tls-key-file", "", "TLS key file")
showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
mainLogger := logmon.New()
flag.Parse() // Parse the command-line flags
if *showVersion {
fmt.Printf("version: %s (%s), built at %s", version, commit, date)
os.Exit(0)
}
conf, err := config.LoadConfig(*configPath)
if err != nil {
mainLogger.Errorf("Error loading config: %v", err)
os.Exit(1)
}
if len(conf.Profiles) > 0 {
mainLogger.Warn("Profile functionality has been removed in favor of Groups. See the README for more information.")
}
switch strings.ToLower(strings.TrimSpace(conf.LogLevel)) {
case "debug":
mainLogger.SetLogLevel(logmon.LevelDebug)
case "info":
mainLogger.SetLogLevel(logmon.LevelInfo)
case "warn":
mainLogger.SetLogLevel(logmon.LevelWarn)
case "error":
mainLogger.SetLogLevel(logmon.LevelError)
default:
mainLogger.SetLogLevel(logmon.LevelInfo)
}
mainLogger.Debugf("PID: %d", os.Getpid())
if mode := os.Getenv("GIN_MODE"); mode != "" {
gin.SetMode(mode)
} else {
gin.SetMode(gin.ReleaseMode)
}
// Validate TLS flags.
var useTLS = (*certFile != "" && *keyFile != "")
if (*certFile != "" && *keyFile == "") ||
(*certFile == "" && *keyFile != "") {
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
os.Exit(1)
}
// Set default ports.
if *listenStr == "" {
defaultPort := ":8080"
if useTLS {
defaultPort = ":8443"
}
listenStr = &defaultPort
}
var mon *perf.Monitor
if !conf.Performance.Disabled {
mon, err = perf.New(conf.Performance, mainLogger)
if err != nil {
mainLogger.Errorf("failed to create monitor: %s", err.Error())
os.Exit(1)
}
mon.Start()
} else {
mainLogger.Info("performance monitoring is disabled")
}
// Setup channels for server management
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
// Context that bounds the lifetime of background watcher goroutines.
watcherCtx, watcherCancel := context.WithCancel(context.Background())
// Create server with initial handlergit
srv := &http.Server{
Addr: *listenStr,
}
// Support for watching config and reloading when it changes
reloading := false
var reloadMutex sync.Mutex
reloadProxyManager := func() {
reloadMutex.Lock()
if reloading {
reloadMutex.Unlock()
return
}
reloading = true
reloadMutex.Unlock()
defer func() {
reloadMutex.Lock()
reloading = false
reloadMutex.Unlock()
}()
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
mainLogger.Info("Reloading Configuration")
conf, err = config.LoadConfig(*configPath)
if err != nil {
mainLogger.Warnf("Unable to reload configuration: %v", err)
return
}
mainLogger.Debug("Configuration Changed")
currentPM.Shutdown()
if mon != nil {
mon.UpdateConfig(conf.Performance)
}
newPM := proxy.New(conf)
newPM.SetVersion(date, commit, version)
newPM.SetPerfMonitor(mon)
srv.Handler = newPM
mainLogger.Debug("Configuration Reloaded")
// wait a few seconds and tell any UI to reload
time.AfterFunc(3*time.Second, func() {
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateEnd,
})
})
} else {
conf, err = config.LoadConfig(*configPath)
if err != nil {
mainLogger.Errorf("Unable to load configuration: %v", err)
os.Exit(1)
}
newPM := proxy.New(conf)
newPM.SetVersion(date, commit, version)
newPM.SetPerfMonitor(mon)
srv.Handler = newPM
}
}
// load the initial proxy manager
reloadProxyManager()
if *watchConfig {
go func() {
absConfigPath, err := filepath.Abs(*configPath)
if err != nil {
mainLogger.Errorf("watch-config unable to determine absolute path for watching config file: %v", err)
return
}
mainLogger.Info("Watching configuration for changes (poll-based, 2s interval)")
(&configwatcher.Watcher{
Path: absConfigPath,
Interval: configwatcher.DefaultInterval,
OnChange: func() {
reloadProxyManager()
},
}).Run(watcherCtx)
}()
}
// Signal handling
go func() {
for {
sig := <-sigChan
switch sig {
case syscall.SIGHUP:
mainLogger.Debug("Received SIGHUP")
reloadProxyManager()
case syscall.SIGINT, syscall.SIGTERM:
mainLogger.Debugf("Received signal %v, shutting down...", sig)
if mon != nil {
mon.Stop()
}
watcherCancel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
pm.Shutdown()
} else {
mainLogger.Error("srv.Handler is not of type *proxy.ProxyManager")
}
if err := srv.Shutdown(ctx); err != nil {
mainLogger.Errorf("Server shutdown: %v", err)
}
close(exitChan)
return
default:
// do nothing on other signals
}
}
}()
// Start server
go func() {
var err error
if useTLS {
mainLogger.Infof("llama-swap listening with TLS on https://%s", *listenStr)
err = srv.ListenAndServeTLS(*certFile, *keyFile)
} else {
mainLogger.Infof("llama-swap listening on http://%s", *listenStr)
err = srv.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
mainLogger.Errorf("Fatal server error: %v", err)
os.Exit(1)
}
}()
// Wait for exit signal
<-exitChan
}
+92
View File
@@ -0,0 +1,92 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/perf"
)
func printSysStat(s perf.SysStat) {
cores := make([]string, len(s.CpuUtilPerCore))
for i, v := range s.CpuUtilPerCore {
cores[i] = fmt.Sprintf("%.1f%%", v)
}
fmt.Printf("[SYS %s]\n", s.Timestamp.Format("15:04:05"))
fmt.Printf(" CPU: %s\n", strings.Join(cores, " "))
fmt.Printf(" Mem: %d MB used / %d MB total (%d MB free)\n", s.MemUsedMB, s.MemTotalMB, s.MemFreeMB)
fmt.Printf(" Swap: %d MB used / %d MB total\n", s.SwapUsedMB, s.SwapTotalMB)
fmt.Printf(" Load: %.2f %.2f %.2f (1m 5m 15m)\n", s.LoadAvg1, s.LoadAvg5, s.LoadAvg15)
}
func printGpuStats(gpus []perf.GpuStat) {
for _, g := range gpus {
fmt.Printf("[GPU %d %s]\n", g.ID, g.Name)
fmt.Printf(" Util: GPU %.1f%% Mem %.1f%%\n", g.GpuUtilPct, g.MemUtilPct)
fmt.Printf(" Mem: %d MB used / %d MB total\n", g.MemUsedMB, g.MemTotalMB)
fmt.Printf(" Temp: %d°C Fan: %.1f%% Power: %.1f W\n", g.TempC, g.FanSpeedPct, g.PowerDrawW)
}
}
func main() {
stream := flag.Bool("stream", false, "stream stats")
interval := flag.Duration("t", time.Second, "polling interval (clamped to 1s1h)")
flag.Parse()
every := *interval
if every < time.Second {
every = time.Second
} else if every > time.Hour {
every = time.Hour
}
l := logmon.New()
l.SetLogLevel(logmon.LevelDebug)
s, err := perf.ReadSysStats()
if err != nil && err != perf.ErrNotImplemented {
fmt.Println("Sys Error:", err)
return
}
printSysStat(s)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
gpuCh, err := perf.GetGpuStats(ctx, every, l)
if err != nil && !errors.Is(err, perf.ErrNotImplemented) && !errors.Is(err, perf.ErrNoGpuTool) {
fmt.Println("GPU Init Error:", err)
return
}
if gpuCh != nil {
select {
case g := <-gpuCh:
printGpuStats(g)
case <-ctx.Done():
fmt.Println("GPU: timed out waiting for stats")
}
}
if *stream {
m, _ := perf.New(config.PerformanceConfig{Every: every}, l)
m.Start()
defer m.Stop()
sysCh, gpuCh, unsub := m.Subscribe()
defer unsub()
for {
select {
case s := <-sysCh:
printSysStat(s)
case g := <-gpuCh:
printGpuStats(g)
}
}
}
}
+96
View File
@@ -0,0 +1,96 @@
package main
import (
"flag"
"fmt"
"os"
"sync"
"time"
tea "github.com/charmbracelet/bubbletea"
)
func main() {
prompt := flag.String("prompt", "Write a few sentences about the history of computing.", "user message sent to each model")
maxTokens := flag.Int("max-tokens", 256, "max_tokens per request")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <base-url> <model> [model...]\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Example: %s -max-tokens 400 http://localhost:8080 A B C D\n\n", os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
args := flag.Args()
if len(args) < 2 {
flag.Usage()
os.Exit(1)
}
baseURL := args[0]
models := args[1:]
m := newModel(models)
prog := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
// Chain of triggers ensures requests are sent in the order provided.
triggers := make([]chan struct{}, len(models))
for i := range triggers {
triggers[i] = make(chan struct{}, 1)
}
triggers[0] <- struct{}{}
var wg sync.WaitGroup
start := time.Now()
for i, name := range models {
wg.Add(1)
go func(idx int, mdl string) {
defer wg.Done()
<-triggers[idx]
reqStart := time.Now()
prog.Send(statusMsg{idx: idx, status: statusStreaming})
if idx+1 < len(triggers) {
triggers[idx+1] <- struct{}{}
}
err := sendRequest(baseURL, mdl, *prompt, *maxTokens, idx, func(i int, text string) {
prog.Send(deltaMsg{idx: i, text: text})
})
elapsed := time.Since(reqStart)
if err != nil {
prog.Send(statusMsg{idx: idx, status: statusError, elapsed: elapsed, err: err})
} else {
prog.Send(statusMsg{idx: idx, status: statusDone, elapsed: elapsed})
}
}(i, name)
}
if _, err := prog.Run(); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
wg.Wait()
printSummary(m, start)
}
func printSummary(m *model, start time.Time) {
fmt.Println("Summary:")
for _, p := range m.panels {
switch p.status {
case statusError:
fmt.Printf(" [%d] %-20s ERROR elapsed=%s err=%v\n",
p.idx, p.model, p.elapsed.Round(time.Millisecond), p.err)
case statusDone:
fmt.Printf(" [%d] %-20s done elapsed=%s\n",
p.idx, p.model, p.elapsed.Round(time.Millisecond))
default:
fmt.Printf(" [%d] %-20s %s\n", p.idx, p.model, p.status)
}
}
fmt.Printf("all done in %s\n", time.Since(start).Round(time.Millisecond))
}
+88
View File
@@ -0,0 +1,88 @@
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// deltaSink receives streamed text fragments for a given model panel.
type deltaSink func(idx int, text string)
type streamDelta struct {
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content"`
}
type streamChoice struct {
Delta streamDelta `json:"delta"`
}
type streamChunk struct {
Choices []streamChoice `json:"choices"`
}
// sendRequest streams a chat completion and forwards each content/reasoning
// delta to sink. Reasoning and assistant content are emitted into the same
// stream so they render together.
func sendRequest(baseURL, model, prompt string, maxTokens, idx int, sink deltaSink) error {
payload := map[string]any{
"model": model,
"messages": []map[string]string{
{"role": "user", "content": prompt},
},
"max_tokens": maxTokens,
"stream": true,
}
body, err := json.Marshal(payload)
if err != nil {
return err
}
resp, err := http.Post(baseURL+"/v1/chat/completions", "application/json", bytes.NewReader(body))
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(b)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" || data == "[DONE]" {
if data == "[DONE]" {
break
}
continue
}
var chunk streamChunk
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
for _, c := range chunk.Choices {
if c.Delta.ReasoningContent != "" {
sink(idx, c.Delta.ReasoningContent)
}
if c.Delta.Content != "" {
sink(idx, c.Delta.Content)
}
}
}
return scanner.Err()
}
+343
View File
@@ -0,0 +1,343 @@
package main
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type panelStatus int
const (
statusWaiting panelStatus = iota
statusStreaming
statusDone
statusError
)
func (s panelStatus) String() string {
switch s {
case statusStreaming:
return "streaming"
case statusDone:
return "done"
case statusError:
return "error"
default:
return "waiting"
}
}
// deltaMsg appends streamed text to a panel.
type deltaMsg struct {
idx int
text string
}
// statusMsg updates a panel's lifecycle state.
type statusMsg struct {
idx int
status panelStatus
elapsed time.Duration
err error
}
type panel struct {
idx int
model string
color lipgloss.Color
status panelStatus
buf strings.Builder
elapsed time.Duration
err error
}
const (
minPanelWidth = 28
maxCols = 3
panelHeight = 9 // total box height including border + header
)
type model struct {
panels []*panel
focused int
vp viewport.Model
width int
height int
cols int
pw int // inner panel content width
ready bool
}
func newModel(models []string) *model {
// Assign a stable color per unique model name (by first appearance).
colorOf := map[string]lipgloss.Color{}
panels := make([]*panel, len(models))
for i, m := range models {
c, ok := colorOf[m]
if !ok {
c = modelPalette[len(colorOf)%len(modelPalette)]
colorOf[m] = c
}
panels[i] = &panel{idx: i, model: m, color: c, status: statusWaiting}
}
return &model{panels: panels, focused: 0}
}
func (m *model) Init() tea.Cmd { return nil }
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
m.relayout()
m.refreshViewport(true)
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "q", "ctrl+c", "esc":
return m, tea.Quit
case "tab", "right", "l":
m.setFocus(m.focused + 1)
return m, nil
case "shift+tab", "left", "h":
m.setFocus(m.focused - 1)
return m, nil
}
var cmd tea.Cmd
m.vp, cmd = m.vp.Update(msg)
return m, cmd
case tea.MouseMsg:
if msg.Action == tea.MouseActionPress && msg.Button == tea.MouseButtonLeft {
if idx, ok := m.panelAt(msg.X, msg.Y); ok {
m.setFocus(idx)
}
return m, nil
}
var cmd tea.Cmd
m.vp, cmd = m.vp.Update(msg)
return m, cmd
case deltaMsg:
p := m.panels[msg.idx]
p.buf.WriteString(msg.text)
if msg.idx == m.focused {
atBottom := m.vp.AtBottom()
m.refreshViewport(false)
if atBottom {
m.vp.GotoBottom()
}
}
return m, nil
case statusMsg:
p := m.panels[msg.idx]
p.status = msg.status
p.elapsed = msg.elapsed
p.err = msg.err
if msg.err != nil {
errTxt := lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Render("\n" + msg.err.Error())
p.buf.WriteString(errTxt)
if msg.idx == m.focused {
m.refreshViewport(false)
m.vp.GotoBottom()
}
}
return m, nil
}
return m, nil
}
func (m *model) setFocus(idx int) {
if len(m.panels) == 0 {
return
}
if idx < 0 {
idx = len(m.panels) - 1
}
if idx >= len(m.panels) {
idx = 0
}
if idx == m.focused {
return
}
m.focused = idx
m.refreshViewport(true)
}
// relayout recomputes grid columns and panel/viewport dimensions.
func (m *model) relayout() {
if m.width < minPanelWidth+4 {
m.cols = 1
} else {
m.cols = m.width / (minPanelWidth + 2)
if m.cols > maxCols {
m.cols = maxCols
}
if m.cols > len(m.panels) {
m.cols = len(m.panels)
}
if m.cols < 1 {
m.cols = 1
}
}
// inner content width: total width / cols, minus borders+padding (4) and gap.
boxOuter := m.width/m.cols - 1
m.pw = boxOuter - 4
if m.pw < 8 {
m.pw = 8
}
m.vp = viewport.New(m.pw, panelHeight-2)
m.ready = true
}
func (m *model) refreshViewport(reset bool) {
if !m.ready || len(m.panels) == 0 {
return
}
content := lipgloss.NewStyle().Width(m.pw).Render(m.panels[m.focused].buf.String())
m.vp.SetContent(content)
if reset {
m.vp.GotoBottom()
}
}
// panelAt maps screen coordinates to a panel index based on the grid layout.
func (m *model) panelAt(x, y int) (int, bool) {
if m.cols == 0 {
return 0, false
}
boxOuterW := m.width/m.cols + 1
col := x / boxOuterW
row := y / panelHeight
idx := row*m.cols + col
if col < m.cols && idx >= 0 && idx < len(m.panels) {
return idx, true
}
return 0, false
}
func (m *model) View() string {
if !m.ready {
return "loading..."
}
rows := []string{}
var current []string
for i, p := range m.panels {
current = append(current, m.renderPanel(p, i == m.focused))
if len(current) == m.cols {
rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...))
current = nil
}
}
if len(current) > 0 {
rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...))
}
grid := lipgloss.JoinVertical(lipgloss.Left, rows...)
footer := lipgloss.NewStyle().Faint(true).Render(
"tab/click: focus panel • wheel/↑↓/pgup/pgdn: scroll focused • q: quit")
return grid + "\n" + footer
}
// modelPalette gives each panel a distinct, readable color for its name.
var modelPalette = []lipgloss.Color{
"39", // blue
"213", // magenta
"214", // orange
"45", // cyan
"141", // purple
"203", // salmon
"82", // lime
"227", // light yellow
}
func statusColor(s panelStatus) lipgloss.Color {
switch s {
case statusStreaming:
return lipgloss.Color("220") // yellow - active
case statusDone:
return lipgloss.Color("42") // green - success
case statusError:
return lipgloss.Color("196") // red - error
default:
return lipgloss.Color("244") // gray - waiting
}
}
func (m *model) renderPanel(p *panel, focused bool) string {
border := lipgloss.RoundedBorder()
if focused {
border = lipgloss.DoubleBorder()
}
style := lipgloss.NewStyle().
Border(border).
BorderForeground(lipgloss.Color("240"))
statusTxt := p.status.String()
if p.elapsed > 0 {
statusTxt += " " + p.elapsed.Round(time.Millisecond).String()
}
// Header: model name (left, model color) + status/timer (right, status color).
name := fmt.Sprintf("[%d] %s", p.idx, p.model)
gap := m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt)
if gap < 1 {
name = truncate(name, m.pw-lipgloss.Width(statusTxt)-1)
gap = m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt)
}
if gap < 1 {
gap = 1
}
header := lipgloss.NewStyle().Bold(true).Foreground(p.color).Render(name) +
strings.Repeat(" ", gap) +
lipgloss.NewStyle().Foreground(statusColor(p.status)).Render(statusTxt)
var bodyLines string
if focused {
bodyLines = m.vp.View()
} else {
bodyLines = tailLines(p.buf.String(), m.pw, panelHeight-2)
}
content := lipgloss.JoinVertical(lipgloss.Left, header, bodyLines)
return style.Width(m.pw).Height(panelHeight - 2).Render(content)
}
func truncate(s string, w int) string {
if w <= 0 {
return ""
}
if lipgloss.Width(s) <= w {
return s
}
r := []rune(s)
if len(r) > w {
r = r[:w]
}
return string(r)
}
// tailLines wraps text to width w and returns the last n lines.
func tailLines(s string, w, n int) string {
wrapped := lipgloss.NewStyle().Width(w).Render(s)
lines := strings.Split(wrapped, "\n")
if len(lines) > n {
lines = lines[len(lines)-n:]
}
for len(lines) < n {
lines = append(lines, "")
}
return strings.Join(lines, "\n")
}
+20 -1
View File
@@ -142,6 +142,25 @@
"default": 5,
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
},
"performance": {
"type": "object",
"properties": {
"disabled": {
"type": "boolean",
"default": false,
"description": "Disable system performance monitoring."
},
"every": {
"type": "string",
"pattern": "^[-+]?(\\d+(\\.\\d+)?(ns|us|ms|s|m|h))+$",
"default": "15s",
"description": "Delay between polling for new performance statistics. Minimum duration is 1s. Lower values use more RAM as stats are kept in memory."
}
},
"additionalProperties": false,
"default": {},
"description": "Configuration for CPU, RAM and GPU monitoring statistics."
},
"startPort": {
"type": "integer",
"default": 5800,
@@ -517,4 +536,4 @@
}
}
]
}
}
+27 -6
View File
@@ -55,6 +55,18 @@ metricsMaxInMemory: 1000
# - set to 0 to disable
captureBuffer: 15
# performance: configuration for system monitoring statistics
# - timing values are duration strings like 1s, 1h30m, 90m, 2h10s, etc.
performance:
# disabled: boolean
# - default: false
disabled: false
# every: delay between polling for new performance statistics
# - default: 5s
# - minimum duration 5s
every: 15s
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
@@ -96,8 +108,7 @@ globalTTL: 0
macros:
# Example of a multi-line macro
"latest-llama": >
/path/to/llama-server/llama-server-ec9e0301
--port ${PORT}
/path/to/llama-server/llama-server-ec9e0301 --port ${PORT}
"default_ctx": 4096
@@ -257,7 +268,8 @@ models:
# the ${temp} macro will remain a float
temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp},
context=${default_ctx}"
a_list:
- 1
@@ -335,11 +347,20 @@ models:
# matrix: run concurrent models with a solver-based swap DSL
# =============================================================================
#
# Note:
# A config must use either a matrix or legacy groups, not both. A configuration error
# will occur if both are defined. Configuration examples for legacy Groups can be found:
# Matrix or Groups?
#
# Groups are available and fully supported. The syntax may be easier to use
# for simple use cases.
#
# Documentation can be found here:
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
#
# A config can only use a matrix (recommended) or groups. A configuration error
# will occur if both are defined. Groups is legacy but is fully supported with
# no plans to deprecate it.
#
# ~~~~~
#
# The matrix declares valid combinations of models that can run concurrently.
# When a model is requested, the solver finds the cheapest way to make it
# available by evicting as few (and least costly) running models as possible.
+59 -9
View File
@@ -46,13 +46,31 @@ fi
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
# to enable easy container builds on forked repos
# LS_REPO is the destination of the built container image — defaults to the
# current GitHub repository so forked CI builds publish to the fork's own
# ghcr.io namespace without code changes.
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
# LS_BINARY_REPO is where the llama-swap release tarball is downloaded
# from. Decoupled from LS_REPO so forks (which usually have no releases of
# their own) can still build a container by pulling the canonical binary
# from upstream. Override via the LS_BINARY_REPO env var when you maintain
# fork-side releases.
LS_BINARY_REPO=${LS_BINARY_REPO:-mostlygeek/llama-swap}
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
# have to strip out the 'v' due to .tar.gz file naming.
# Authenticated request — unauth'd github.com API is 60/hr per IP and GHA
# runners share IPs, so the call regularly returns rate-limit JSON and
# `.tag_name` then resolves to "null", producing a bogus `vnull` URL below.
LS_VER=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/repos/${LS_BINARY_REPO}/releases/latest" \
| jq -r .tag_name | sed 's/v//')
if [[ -z "$LS_VER" || "$LS_VER" == "null" ]]; then
log_info "Error: could not resolve latest llama-swap release tag from ${LS_BINARY_REPO}"
exit 1
fi
# Fetches the most recent llama.cpp tag matching the given prefix
# Handles pagination to search beyond the first 100 results
@@ -126,6 +144,25 @@ if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
exit 0
fi
# cpu is the only backend with a multi-arch upstream base
# (ghcr.io/ggml-org/llama.cpp:server-bXXXX ships amd64+arm64); GPU backends
# are amd64-only and stay on the original `docker build` path so the
# sd-server layer can still FROM the just-built image via the local
# dockerd image store (buildx's container driver has a separate store
# that doesn't share with dockerd, which breaks the sd build).
if [ "$ARCH" == "cpu" ]; then
if [ "$PUSH_IMAGES" == "true" ]; then
BUILDX_FLAGS="--push --platform linux/amd64,linux/arm64"
else
# Smoke build: validate both platforms but emit no output. buildx
# on the docker-container driver defaults to cacheonly when
# neither --push nor --load is given, so each arch fully builds
# and a regression in either fails CI — without materializing the
# image or needing to --load (which is multi-arch-incompatible).
BUILDX_FLAGS="--platform linux/amd64,linux/arm64"
fi
fi
for CONTAINER_TYPE in non-root root; do
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
@@ -142,11 +179,23 @@ for CONTAINER_TYPE in non-root root; do
fi
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
--build-arg BASE_IMAGE=${BASE_IMAGE} .
if [ "$ARCH" == "cpu" ]; then
docker buildx build $BUILDX_FLAGS --provenance=false \
-f llama-swap.Containerfile \
--build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_BINARY_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} \
--build-arg BASE_IMAGE=${BASE_IMAGE} \
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
else
docker build --provenance=false -f llama-swap.Containerfile \
--build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_BINARY_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} \
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
--build-arg BASE_IMAGE=${BASE_IMAGE} .
fi
# For architectures with stable-diffusion.cpp support, layer sd-server on top
# For architectures with stable-diffusion.cpp support, layer sd-server on top.
# Stays on `docker build` so the base resolves from local dockerd.
case "$ARCH" in
"musa" | "vulkan")
log_info "Adding sd-server to $CONTAINER_TAG"
@@ -157,7 +206,8 @@ for CONTAINER_TYPE in non-root root; do
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
esac
if [ "$PUSH_IMAGES" == "true" ]; then
# cpu builds push inline via buildx --push; all other archs push here.
if [ "$ARCH" != "cpu" ] && [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
+6 -3
View File
@@ -3,6 +3,9 @@ ARG BASE_TAG=server-cuda
FROM ${BASE_IMAGE}:${BASE_TAG}
# has to be after the FROM
# TARGETARCH is auto-set by `docker buildx build --platform …` (amd64/arm64);
# falls back to amd64 when an older `docker build` runs without buildx.
ARG TARGETARCH=amd64
ARG LS_VER=170
ARG LS_REPO=mostlygeek/llama-swap
@@ -34,9 +37,9 @@ WORKDIR /app
ENV PATH="/app:${PATH}"
RUN \
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
tar -zxf "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz" && \
rm "llama-swap_${LS_VER}_linux_${TARGETARCH}.tar.gz"
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
+15 -3
View File
@@ -146,6 +146,18 @@ metricsMaxInMemory: 1000
# - set to 0 to disable
captureBuffer: 15
# performance: configuration for system monitoring statistics
# - timing values are duration strings like 1s, 1h30m, 90m, 2h10s, etc.
performance:
# disabled: boolean
# - default: false
enable: true
# every: delay between polling for new performance statistics
# - default: 5s
# - minimum duration 5s
every: 5s
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
@@ -187,8 +199,7 @@ globalTTL: 0
macros:
# Example of a multi-line macro
"latest-llama": >
/path/to/llama-server/llama-server-ec9e0301
--port ${PORT}
/path/to/llama-server/llama-server-ec9e0301 --port ${PORT}
"default_ctx": 4096
@@ -348,7 +359,8 @@ models:
# the ${temp} macro will remain a float
temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp},
context=${default_ctx}"
a_list:
- 1
File diff suppressed because it is too large Load Diff
+264
View File
@@ -0,0 +1,264 @@
# New Router Migration TODO
This document tracks the work needed for [cmd/newrouter/main.go](../cmd/newrouter/main.go) and [internal/router/](../internal/router/) to reach feature parity with the legacy entrypoint at [llama-swap.go](../llama-swap.go) plus [proxy/proxymanager.go](../proxy/proxymanager.go).
The work is split into phases so each can land and be tested independently. Earlier phases unblock later ones.
## Current state (newrouter)
`cmd/newrouter` already supports:
- Loading config via `-config`
- Selecting Matrix vs Group router based on config
- Peer routing fallback
- Plain HTTP listen (`-listen`)
- Graceful shutdown on `SIGINT` / `SIGTERM`
- Model extraction from JSON body, query string, and form bodies (see [router.go:88](../internal/router/router.go#L88))
- `Server.ServeHTTP` dispatches a single request to peer or local router based on the requested model
Everything below is missing or only partially implemented.
---
## Phase 1 — Package relocation -- Completed.
Goal: move shared infrastructure packages out from under `proxy/` so the new router does not depend on the legacy proxy tree. This is a prerequisite for retiring `proxy/` in Phase 8.
---
## Phase 2 — Server lifecycle parity -- Completed.
Goal: make `cmd/newrouter` a drop-in replacement for the legacy binary's process model, _without_ yet adding any extra HTTP endpoints.
---
## Phase 3 — `internal/chain` package -- Completed.
API: `chain.New(mws...).Then(final)` for ServeMux registration; `Append` returns an extended Chain without mutating the receiver, so a base stack (auth/CORS) can be reused across many routes with per-route additions.
---
## Phase 4 — `internal/server` package scaffolding (ProxyManager replacement) -- Completed.
Goal: build the [internal/server](../internal/server/) package so it can stand in for [proxy.ProxyManager](../proxy/proxymanager.go#L67) — the mux, lifecycle, model dispatch, custom endpoints, request filters, auth/CORS, and upstream passthrough. After this phase, `cmd/newrouter/main.go` constructs a `server.Server` instead of a bare `router.Server`.
The legacy `ProxyManager` collapses three concerns into one struct: the HTTP mux, the model→process router, and the cross-cutting services (loggers, metrics, perf, inflight counter, version). The new layout keeps the `router.Router` implementations focused on model dispatch and lets `internal/server.Server` own the mux and all cross-cutting middleware. `server.Server` builds the `local` and `peer` routers directly and dispatches between them itself, so it fully **supersedes `internal/router.Server`** — see the cleanup item below.
The phase is split into sub-phases that can land and be tested independently:
| Sub-phase | Scope |
| --------- | -------------------------------------------------------------------------- |
| 4a | package scaffolding — struct, `New`, `ServeHTTP`, `Shutdown`, model routes |
| 4b | custom (non-model-dispatched) HTTP endpoints |
| 4c | request-body filter middleware |
| 4d | auth & CORS middleware |
| 4e | upstream passthrough |
The package is split by concern across stub files already in place:
| File | Responsibility | Filled in by |
| ------------ | ----------------------------------------------- | ---------------------- |
| `server.go` | `Server` struct, `New`, `ServeHTTP`, `Shutdown` | 4a |
| `log.go` | `muxlog` combined logger; `/logs` handlers | 4a |
| `auth.go` | `CreateAuthMiddleware` | 4d |
| `filters.go` | request-body filter middleware | 4c |
| `api.go` | llama-swap-specific API handlers | 4b / Phase 5 / Phase 6 |
| `ui.go` | embedded UI serving | Phase 7 |
### Phase 4a — package scaffolding -- Completed.
`server.Server` owns the mux, the `local`/`peer` routers, `muxlog`, and a
shutdown context. `New` builds the routers, registers all model-dispatched
routes on a stdlib `http.ServeMux`, and wraps the mux with the global CORS
middleware. `localPeerHandler` resolves the model once via `router.FetchModel`
and dispatches to `local` or `peer`. `Shutdown` stops both routers in parallel
and is idempotent. `cmd/newrouter/main.go` now constructs `server.New(...)`;
`internal/router/server.go` and `server_test.go` were removed as dead code.
### Phase 4b — Custom HTTP endpoints -- Completed.
`GET /v1/models` (local + peer models, aliases, metadata), `GET /health`,
`GET /wol-health`, and `GET /``/ui` are registered. `GET /favicon.ico` is
deferred to Phase 7 since it requires the embedded UI filesystem.
### Phase 4c — Request-body filters -- Completed.
`CreateFilterMiddleware` (in `filters.go`) applies `UseModelName`,
`StripParams`, `SetParams`, and `SetParamsByID` to JSON requests, then
re-attaches the body with `Content-Length` / `Transfer-Encoding` cleanup.
### Phase 4d — Auth & CORS -- Completed.
`CreateAuthMiddleware` validates API keys (Bearer / Basic / `x-api-key`) and
strips the headers before upstream. `CreateCORSMiddleware` answers OPTIONS
preflight; `/v1/models` echoes the `Origin`.
### Phase 4e — Upstream passthrough -- Completed.
`GET /upstream``/ui/models`, and `/upstream/<model>/<path>` proxies to the
resolved model with multi-segment name resolution, canonical-form redirect
(301/308), and prefix stripping.
---
## Phase 5 — Operations endpoints -- Completed.
A new `router.LocalRouter` interface embeds `Router` and adds `RunningModels()`
and `Unload(timeout, models...)`, both implemented once on `baseRouter` so
`Group` and `Matrix` share them — the legacy matrix/group divergence at
[proxymanager.go:1167](../proxy/proxymanager.go#L1167) collapses since
`baseRouter` already unifies process storage. `Peer` does not implement it;
`Server.local` is typed `LocalRouter`, `Server.peer` stays `Router`.
`GET /unload` stops every local process; `GET /running` lists non-stopped
processes joined against config for `cmd`/`proxy`/`ttl`/`name`/`description`.
`startPreload` fires a background `GET /` at each `Hooks.OnStartup.Preload`
model and emits `shared.ModelPreloadedEvent`.
---
## Phase 6 — Metrics, perf, and SSE -- Completed.
`perf.Monitor` is created and started in `cmd/newrouter/main.go` (it outlives
config reloads via `UpdateConfig`) and passed into `server.New`. `GET /metrics`
serves `perf.Monitor.MetricsHandler()` output, 503 when disabled.
`internal/process` emits `shared.ProcessStateChangeEvent` from `setState`.
`server.inflightCounter` (atomic) + `CreateInflightMiddleware` track
model-dispatched requests and emit `InFlightRequestsEvent`. `metricsMonitor`
(in `metrics.go`) parses token usage from upstream responses via
`CreateMetricsMiddleware`.
The `/api` group (API-key protected) is registered: `POST /api/models/unload`,
`POST /api/models/unload/{model...}`, `GET /api/events` (SSE: `modelStatus` /
`logData` / `metrics` / `inflight`), `GET /api/metrics`, `GET /api/performance`
(`?after=` RFC3339 filter), `GET /api/version`. `GET /api/captures/{id}`
returns 501 until 6f.
### Phase 6f — Request/response captures -- Completed.
`proxy/cache` moved to `internal/cache`. `metricsMonitor` stores zstd+CBOR
`ReqRespCapture` records in a sized `cache.Cache` (`captureBuffer` MB, 0
disables). `CreateMetricsMiddleware` buffers request body/headers before
dispatch; `record` builds the capture per a `captureFieldsByPath` table
(`captures.go`) that trims large audio/image payloads, defaulting JSON routes
to `captureAll`. `GET /api/captures/{id}` decompresses and returns the capture;
`getMetrics` resolves `HasCapture` against the cache.
---
## Phase 7 — UI serving -- Completed.
`internal/server/ui.go` embeds `ui_dist` and serves it. `GET /ui/` is
brotli/gzip-aware via `serveCompressedFile`; unknown paths without a file
extension fall back to `index.html` for SPA routing. `GET /favicon.ico` serves
from the same embedded FS. The Makefile `ui` target copies the vite build into
`internal/server/ui_dist`; a committed `placeholder.txt` keeps the embed valid
before a build runs.
---
## Phase 8a - Review Part I
- [x] All functionality from the proxy package has been migrated in the above phases — with the remaining gaps listed in Phase 8b
- [x] Test coverage at or exceeds the level from the proxy package — `internal/server` now at 76.6% vs 73.9% (`proxy`)
### Findings
**Gap 1 — Request logging middleware missing -- Resolved.**
`CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records one
access-log line per request to `s.proxylog` in the legacy format
`clientIP "METHOD PATH PROTO" status bodySize "UA" duration`, skipping
`/wol-health`, `/api/performance`, and `/metrics`. A `statusRecorder` captures
the status/body size (forwarding `Flush` for SSE) and `clientIP` honours
`X-Forwarded-For` / `X-Real-IP`. It is wired as the outermost middleware in
`routes()`, wrapping the CORS layer.
**Gap 2 — Per-model log streaming not supported -- Resolved **
`Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) only handles `""`, `"proxy"`, and `"upstream"`. The legacy `ProxyManager.getLogger` ([proxymanager_loghandlers.go:92](../proxy/proxymanager_loghandlers.go#L92)) additionally resolves a model ID against the active process groups / matrix and returns that process's logger. Callers of `GET /logs/stream/<modelID>` will get a 400 instead of the model's live log stream.
**Gap 3 — `UseModelName` not applied to multipart form endpoints -- Resolved.**
`CreateFormFilterMiddleware` ([filters.go](../internal/server/filters.go)) parses
`multipart/form-data` requests, rewrites the `model` field with `UseModelName`,
reconstructs the body via `rewriteMultipartModel`, and re-attaches it with
`Content-Type` / `Content-Length` cleanup. It runs in `modelChain` after the
JSON `filterMW`; each is a no-op for the other's Content-Type. Audio
transcription (`/v1/audio/transcriptions`) and image edit (`/v1/images/edits`)
now honour `use_model_name`.
**Coverage gaps (0 % functions) -- Resolved.**
The functions previously at 0 % (`handleListModels`, `handleMetrics`,
`handleRootRedirect`, `handleUpstreamRedirect`, `handleUpstream`,
`findModelInPath`, `handleAPICapture`, `handleAPIUnloadAll`,
`handleAPIUnloadModel`, `CreateAuthMiddleware`, `extractAPIKey`,
`handleLogStream`, `applyFilters`, `decompressBody`, `filterAcceptEncoding`,
`handleUI`, `handleFavicon`) now have tests across `auth_test.go`, `api_test.go`,
`filters_test.go`, `log_test.go`, and `extras_test.go`.
---
### Phase 8b - Fill gaps discovered in Phase 8a
- [x] **Add request-log middleware**`CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records `clientIP "METHOD PATH PROTO" status bodySize "UA" duration` to `s.proxylog`, skips `/wol-health` / `/api/performance` / `/metrics`, and is wired as the outermost middleware in `routes()`.
- [x] **Extend `getLogger` with model-ID resolution** — add a `default:` branch to `Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) that resolves the ID via `s.local` (using a new `LocalRouter.GetProcess(name)` method or equivalent) and returns that process's `Logger()`. Match the fallback behaviour: return a 400 with `"invalid logger. Use 'proxy', 'upstream' or a model's ID"` when not found.
- [x] **`UseModelName` rewrite for multipart endpoints** — `CreateFormFilterMiddleware` parses `multipart/form-data`, rewrites the `model` field according to `UseModelName`, reconstructs the body, and updates `Content-Type` / `Content-Length`. It is wired into `modelChain` after the JSON filter.
- [x] **Raise test coverage to ≥ 74 %**`internal/server` now at 76.1%; tests added for every 0 % function across `auth_test.go`, `api_test.go`, `filters_test.go`, `log_test.go`, and `extras_test.go`.
---
## Phase 8c - Review Part II (entrypoint comparison)
A second pass comparing [cmd/newrouter/main.go](../cmd/newrouter/main.go) against
the legacy [llama-swap.go](../llama-swap.go) + [proxy.New](../proxy/proxymanager.go#L104)
surfaced four more gaps, all in logger setup.
**Gap 4 — `LogToStdout` config ignored -- Resolved.**
`cmd/newrouter/main.go` previously hardcoded `proxyLog` / `upstreamLog` to
`os.Stdout`, and the old `muxlog()` helper built a Monitor that nothing wrote
into — so `logToStdout` had no effect and `/logs` (combined history) was always
empty. `server.NewLoggers` ([log.go](../internal/server/log.go)) now replicates
the legacy switch: `proxy` / `upstream` monitors feed `muxLog` (or `io.Discard`)
per `none` / `both` / `upstream` / `proxy`, so `muxLog` accumulates the combined
history. `server.New` takes `muxlog` as a parameter. The loggers outlive config
reloads, so a `LogToStdout` change requires a restart to take effect.
**Gap 5 — `LogTimeFormat` config ignored -- Resolved.**
`cmd/newrouter/main.go` now maps `cfg.LogTimeFormat` to a Go time layout via the
`logTimeFormats` table and applies it (alongside log level) to the proxy and
upstream monitors in `applyLogSettings`, re-applied on config reload.
**Gap 6 — `LogRequests` deprecation warning missing.**
The legacy [proxymanager.go:127](../proxy/proxymanager.go#L127) warns when the
deprecated `logRequests` config key is set. `cmd/newrouter` does not. Low
priority — left open.
**Gap 7 — PID debug log missing -- Resolved.**
`cmd/newrouter/main.go` now logs `PID: %d` at debug level after `applyLogSettings`,
matching [llama-swap.go:71](../llama-swap.go#L71).
---
## Phase X (tbd) — Cutover
- [ ] Swap `llama-swap.go` to delegate to `cmd/newrouter` (or rename newrouter to be the primary entrypoint)
- [ ] Update `Makefile` build targets
- [ ] Update docs / README references to the legacy binary
- [ ] Remove `proxy/proxymanager*.go` and `gin-gonic` dependency once nothing imports them
- [ ] Run `make test-all` and confirm concurrency suite still passes against the new entrypoint
---
## Cross-cutting concerns to keep in mind
- **Single body read**: legacy and newrouter both buffer the request body once. When adding filters (Phase 4c), make sure the buffered bytes flow through `Content-Length` / `transfer-encoding` cleanup as in [proxymanager.go:872](../proxy/proxymanager.go#L872).
- **Streaming flag in context**: legacy stashes `streaming` and `model` under `proxyCtxKey`. The new router uses `ModelKey` / `ModelIDKey` — pick one set of keys and use them consistently for metrics + log handlers.
- **Matrix vs Group divergence**: any handler that calls `swapProcessGroup` or `findGroupByModelName` in the legacy needs a matrix branch too. The new router's `Router` interface already abstracts this — preserve that abstraction rather than reintroducing the branch in every handler.
- **Shutdown ordering**: `httpServer.Shutdown` must drain inflight requests _before_ `Server.Shutdown` tears down processes, otherwise inflight requests 502. Current newrouter ordering at [main.go:87](../cmd/newrouter/main.go#L87) is correct — keep it.
+31 -2
View File
@@ -4,23 +4,40 @@ go 1.26.1
require (
github.com/billziss-gh/golib v0.2.0
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fxamacker/cbor/v2 v2.9.1
github.com/gin-gonic/gin v1.10.0
github.com/klauspost/compress v1.18.5
github.com/stretchr/testify v1.9.0
github.com/shirou/gopsutil/v4 v4.26.4
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/sync v0.20.0
golang.org/x/sys v0.41.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/ebitengine/purego v0.10.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
@@ -28,20 +45,32 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tklauser/go-sysconf v0.3.16 // indirect
github.com/tklauser/numcpus v0.11.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+70 -6
View File
@@ -1,9 +1,31 @@
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
@@ -11,6 +33,10 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -19,6 +45,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -29,8 +57,9 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
@@ -42,17 +71,37 @@ github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZY
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/shirou/gopsutil/v4 v4.26.4 h1:B4SXVbcwTyrocPHEmWBC4uCYr4Xcu3MK1TXqbprAOWY=
github.com/shirou/gopsutil/v4 v4.26.4/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -63,8 +112,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -75,26 +125,40 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
View File
+63
View File
@@ -0,0 +1,63 @@
// Package chain composes http.Handler middleware into a single handler.
//
// A Middleware wraps a downstream http.Handler and may run logic before or
// after delegating to it, or short-circuit by not calling next at all
// (e.g. auth failure, CORS preflight).
package chain
import "net/http"
// Middleware wraps an http.Handler with cross-cutting behavior. It receives
// the next handler in the chain and returns a handler that may call next,
// modify the request/response around it, or short-circuit.
type Middleware func(next http.Handler) http.Handler
// Chain is a reusable middleware stack. Build it once with New (and optionally
// extend per-route with Append), then call Then to wrap each terminal handler
// when registering routes against an http.ServeMux:
//
// api := chain.New(authMW, corsMW)
// mux.Handle("/v1/chat/completions", api.Then(dispatch))
// mux.Handle("/v1/embeddings", api.Append(filters).Then(dispatch))
//
// Middlewares execute left-to-right: mws[0] runs first and may call into
// mws[1], and so on, with the terminal handler invoked last. A middleware
// that does not call next short-circuits the remainder of the chain.
// A zero Chain is valid and applies no middleware.
type Chain struct {
mws []Middleware
}
// New returns a Chain that applies mws left-to-right around any terminal
// handler passed to Then.
func New(mws ...Middleware) Chain {
cp := make([]Middleware, len(mws))
copy(cp, mws)
return Chain{mws: cp}
}
// Append returns a new Chain with mws added after the existing middleware.
// The receiver is not modified, so a base Chain can be safely reused across
// multiple routes that each need different per-route additions.
func (c Chain) Append(mws ...Middleware) Chain {
out := make([]Middleware, 0, len(c.mws)+len(mws))
out = append(out, c.mws...)
out = append(out, mws...)
return Chain{mws: out}
}
// Then wraps final with the chain's middleware and returns the resulting
// handler, suitable for passing to http.ServeMux.Handle. With an empty chain,
// Then returns final unchanged.
func (c Chain) Then(final http.Handler) http.Handler {
h := final
for i := len(c.mws) - 1; i >= 0; i-- {
h = c.mws[i](h)
}
return h
}
// ThenFunc is shorthand for Then(http.HandlerFunc(f)).
func (c Chain) ThenFunc(f http.HandlerFunc) http.Handler {
return c.Then(f)
}
+205
View File
@@ -0,0 +1,205 @@
package chain
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// recordingMiddleware appends tag before calling next and "-after-"+tag after.
func recordingMiddleware(tag string, log *[]string) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
*log = append(*log, tag)
next.ServeHTTP(w, r)
*log = append(*log, "after-"+tag)
})
}
}
func TestChain_HandlersExecuteInDeclaredOrder(t *testing.T) {
var log []string
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "final")
})
h := New(
recordingMiddleware("a", &log),
recordingMiddleware("b", &log),
recordingMiddleware("c", &log),
).Then(final)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
h.ServeHTTP(rec, req)
want := []string{"a", "b", "c", "final", "after-c", "after-b", "after-a"}
if !equal(log, want) {
t.Fatalf("execution order mismatch:\n got: %v\nwant: %v", log, want)
}
}
func TestChain_ShortCircuitsWhenMiddlewareDoesNotCallNext(t *testing.T) {
var log []string
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "final")
})
gate := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "gate")
w.WriteHeader(http.StatusUnauthorized)
})
}
h := New(
recordingMiddleware("outer", &log),
gate,
recordingMiddleware("inner", &log),
).Then(final)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
h.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
}
want := []string{"outer", "gate", "after-outer"}
if !equal(log, want) {
t.Fatalf("short-circuit order mismatch:\n got: %v\nwant: %v", log, want)
}
}
func TestChain_EarlyWritesAreVisibleToLaterMiddleware(t *testing.T) {
header := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Set-By", "outer")
_, _ = io.WriteString(w, "outer:")
next.ServeHTTP(w, r)
})
}
inner := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// The outer middleware already set the header; we should see it.
if got := w.Header().Get("X-Set-By"); got != "outer" {
_, _ = io.WriteString(w, "missing-header;")
}
_, _ = io.WriteString(w, "inner:")
next.ServeHTTP(w, r)
})
}
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "final")
})
h := New(header, inner).Then(final)
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
body, _ := io.ReadAll(rec.Body)
if got := string(body); !strings.Contains(got, "outer:inner:final") {
t.Fatalf("body: got %q, want it to contain %q", got, "outer:inner:final")
}
if got := rec.Header().Get("X-Set-By"); got != "outer" {
t.Fatalf("header X-Set-By: got %q, want %q", got, "outer")
}
}
func TestChain_ReusableAcrossRoutesViaThen(t *testing.T) {
var log []string
base := New(
recordingMiddleware("auth", &log),
recordingMiddleware("cors", &log),
)
mux := http.NewServeMux()
mux.Handle("/a", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "handler-a")
}))
mux.Handle("/b", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "handler-b")
}))
srv := httptest.NewServer(mux)
defer srv.Close()
for _, path := range []string{"/a", "/b"} {
resp, err := http.Get(srv.URL + path)
if err != nil {
t.Fatalf("GET %s: %v", path, err)
}
resp.Body.Close()
}
want := []string{
"auth", "cors", "handler-a", "after-cors", "after-auth",
"auth", "cors", "handler-b", "after-cors", "after-auth",
}
if !equal(log, want) {
t.Fatalf("reusable chain order mismatch:\n got: %v\nwant: %v", log, want)
}
}
func TestChain_AppendDoesNotMutateReceiver(t *testing.T) {
var log []string
base := New(recordingMiddleware("base", &log))
extended := base.Append(recordingMiddleware("extra", &log))
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "final")
})
// Run extended first to surface any aliasing of the underlying slice.
rec := httptest.NewRecorder()
extended.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
rec = httptest.NewRecorder()
base.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
want := []string{
"base", "extra", "final", "after-extra", "after-base",
"base", "final", "after-base",
}
if !equal(log, want) {
t.Fatalf("Append must not mutate the receiver:\n got: %v\nwant: %v", log, want)
}
}
func TestChain_ZeroValueAndEmptyThenAreIdentity(t *testing.T) {
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
})
for name, c := range map[string]Chain{
"zero": {},
"empty": New(),
} {
t.Run(name, func(t *testing.T) {
h := c.Then(final)
if _, ok := h.(http.HandlerFunc); !ok {
t.Fatalf("expected http.HandlerFunc identity, got %T", h)
}
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusTeapot {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusTeapot)
}
})
}
}
func equal(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
@@ -9,6 +9,7 @@ import (
"runtime"
"sort"
"strings"
"time"
"github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3"
@@ -124,6 +125,7 @@ type Config struct {
LogToStdout string `yaml:"logToStdout"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
CaptureBuffer int `yaml:"captureBuffer"`
Performance PerformanceConfig `yaml:"performance"`
GlobalTTL int `yaml:"globalTTL"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
@@ -220,6 +222,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
config.HealthCheckTimeout = 15
}
// Apply defaults for performance config when section is missing
if config.Performance.Every == 0 {
config.Performance.Every = 5 * time.Second
}
if err = config.Performance.Validate(); err != nil {
return Config{}, fmt.Errorf("performance: %w", err)
}
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
@@ -262,6 +272,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
// Strip comments from command fields
modelConfig.Cmd = StripComments(modelConfig.Cmd)
@@ -7,6 +7,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@@ -188,47 +189,54 @@ groups:
SendLoadingState: false,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
},
"model2": {
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
},
},
HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
Performance: PerformanceConfig{
Every: 5 * time.Second,
},
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
@@ -7,6 +7,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@@ -175,49 +176,56 @@ groups:
SendLoadingState: false,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
},
"model2": {
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
},
},
HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
Performance: PerformanceConfig{
Every: 5 * time.Second,
},
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
@@ -54,6 +54,9 @@ type ModelConfig struct {
// Timeout settings for proxy connections
Timeouts TimeoutsConfig `yaml:"timeouts"`
// Copy of HealthCheckTimeout from global config
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
+34
View File
@@ -0,0 +1,34 @@
package config
import (
"fmt"
"time"
)
// PerformanceConfig holds configuration for system performance monitoring
type PerformanceConfig struct {
Disabled bool `yaml:"disabled"`
Every time.Duration `yaml:"every"`
}
func (p *PerformanceConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawPerformanceConfig PerformanceConfig
defaults := rawPerformanceConfig{
Every: 5 * time.Second,
}
if err := unmarshal(&defaults); err != nil {
return err
}
*p = PerformanceConfig(defaults)
return nil
}
// Validate checks the PerformanceConfig values and returns an error if invalid
func (p *PerformanceConfig) Validate() error {
if p.Every < 5*time.Second {
return fmt.Errorf("every must be at least 5s, got %v", p.Every)
}
return nil
}
+98
View File
@@ -0,0 +1,98 @@
package config
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPerformanceConfig_Defaults(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// When performance section is missing, defaults should be applied
assert.False(t, config.Performance.Disabled)
assert.Equal(t, 5*time.Second, config.Performance.Every)
}
func TestPerformanceConfig_CustomValues(t *testing.T) {
content := `
performance:
enable: true
every: 30s
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.False(t, config.Performance.Disabled)
assert.Equal(t, 30*time.Second, config.Performance.Every)
}
func TestPerformanceConfig_Disabled(t *testing.T) {
content := `
performance:
disabled: true
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.True(t, config.Performance.Disabled)
// Duration defaults should still apply
assert.Equal(t, 5*time.Second, config.Performance.Every)
}
func TestPerformanceConfig_PartialValues(t *testing.T) {
content := `
performance:
every: 10s
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// enable should default to true
assert.False(t, config.Performance.Disabled)
assert.Equal(t, 10*time.Second, config.Performance.Every)
}
func TestPerformanceConfig_InvalidEvery(t *testing.T) {
content := `
performance:
every: 4s
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "every must be at least 5s")
}
func TestPerformanceConfig_ComplexDurations(t *testing.T) {
content := `
performance:
every: 1m30s
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 90*time.Second, config.Performance.Every)
}
@@ -1,54 +1,54 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
)
/*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
*/
func BenchmarkSubscribeConcurrent(b *testing.B) {
d := NewDispatcher()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
unsub := Subscribe(d, func(ev MyEvent1) {})
unsub()
}
})
}
func TestDefaultPublish(t *testing.T) {
var wg sync.WaitGroup
// Subscribe
var count int64
defer On(func(ev MyEvent1) {
atomic.AddInt64(&count, 1)
wg.Done()
})()
defer OnType(TypeEvent1, func(ev MyEvent1) {
atomic.AddInt64(&count, 1)
wg.Done()
})()
// Publish
wg.Add(4)
Emit(MyEvent1{})
Emit(MyEvent1{})
// Wait and check
wg.Wait()
assert.Equal(t, int64(4), count)
}
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
)
/*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
*/
func BenchmarkSubscribeConcurrent(b *testing.B) {
d := NewDispatcher()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
unsub := Subscribe(d, func(ev MyEvent1) {})
unsub()
}
})
}
func TestDefaultPublish(t *testing.T) {
var wg sync.WaitGroup
// Subscribe
var count int64
defer On(func(ev MyEvent1) {
atomic.AddInt64(&count, 1)
wg.Done()
})()
defer OnType(TypeEvent1, func(ev MyEvent1) {
atomic.AddInt64(&count, 1)
wg.Done()
})()
// Publish
wg.Add(4)
Emit(MyEvent1{})
Emit(MyEvent1{})
// Wait and check
wg.Wait()
assert.Equal(t, int64(4), count)
}
@@ -1,324 +1,324 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPublish(t *testing.T) {
d := NewDispatcher()
var wg sync.WaitGroup
// Subscribe, must be received in order
var count int64
defer Subscribe(d, func(ev MyEvent1) {
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
wg.Done()
})()
// Publish
wg.Add(3)
Publish(d, MyEvent1{Number: 1})
Publish(d, MyEvent1{Number: 2})
Publish(d, MyEvent1{Number: 3})
// Wait and check
wg.Wait()
assert.Equal(t, int64(3), count)
}
func TestUnsubscribe(t *testing.T) {
d := NewDispatcher()
assert.Equal(t, 0, d.count(TypeEvent1))
unsubscribe := Subscribe(d, func(ev MyEvent1) {
// Nothing
})
assert.Equal(t, 1, d.count(TypeEvent1))
unsubscribe()
assert.Equal(t, 0, d.count(TypeEvent1))
}
func TestConcurrent(t *testing.T) {
const max = 1000000
var count int64
var wg sync.WaitGroup
wg.Add(1)
d := NewDispatcher()
defer Subscribe(d, func(ev MyEvent1) {
if current := atomic.AddInt64(&count, 1); current == max {
wg.Done()
}
})()
// Asynchronously publish
go func() {
for i := 0; i < max; i++ {
Publish(d, MyEvent1{})
}
}()
defer Subscribe(d, func(ev MyEvent1) {
// Subscriber that does nothing
})()
wg.Wait()
assert.Equal(t, max, int(count))
}
func TestSubscribeDifferentType(t *testing.T) {
d := NewDispatcher()
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
})
}
func TestPublishDifferentType(t *testing.T) {
d := NewDispatcher()
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
Publish(d, MyEvent1{})
})
}
func TestCloseDispatcher(t *testing.T) {
d := NewDispatcher()
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
assert.NoError(t, d.Close())
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
})
}
func TestMatrix(t *testing.T) {
const amount = 1000
for _, subs := range []int{1, 10, 100} {
for _, topics := range []int{1, 10} {
expected := subs * topics * amount
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
var count atomic.Int64
var wg sync.WaitGroup
wg.Add(expected)
d := NewDispatcher()
for i := 0; i < subs; i++ {
for id := 0; id < topics; id++ {
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
count.Add(1)
wg.Done()
})()
}
}
for n := 0; n < amount; n++ {
for id := 0; id < topics; id++ {
go Publish(d, MyEvent3{ID: id})
}
}
wg.Wait()
assert.Equal(t, expected, int(count.Load()))
})
}
}
}
func TestConcurrentSubscriptionRace(t *testing.T) {
// This test specifically targets the race condition that occurs when multiple
// goroutines try to subscribe to different event types simultaneously.
// Without the CAS loop, subscriptions could be lost due to registry corruption.
const numGoroutines = 100
const numEventTypes = 50
d := NewDispatcher()
defer d.Close()
var wg sync.WaitGroup
var receivedCount int64
var subscribedTypes sync.Map // Thread-safe map
wg.Add(numGoroutines)
// Start multiple goroutines that subscribe to different event types concurrently
for i := 0; i < numGoroutines; i++ {
go func(goroutineID int) {
defer wg.Done()
// Each goroutine subscribes to a unique event type
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
// Subscribe to the event type
SubscribeTo(d, eventType, func(ev MyEvent3) {
atomic.AddInt64(&receivedCount, 1)
})
// Record that this type was subscribed
subscribedTypes.Store(eventType, true)
}(i)
}
// Wait for all subscriptions to complete
wg.Wait()
// Count the number of unique event types subscribed
expectedTypes := 0
subscribedTypes.Range(func(key, value interface{}) bool {
expectedTypes++
return true
})
// Small delay to ensure all subscriptions are fully processed
time.Sleep(10 * time.Millisecond)
// Publish events to each subscribed type
subscribedTypes.Range(func(key, value interface{}) bool {
eventType := key.(uint32)
Publish(d, MyEvent3{ID: int(eventType)})
return true
})
// Wait for all events to be processed
time.Sleep(50 * time.Millisecond)
// Verify that we received at least the expected number of events
// (there might be more if multiple goroutines subscribed to the same event type)
received := atomic.LoadInt64(&receivedCount)
assert.GreaterOrEqual(t, int(received), expectedTypes,
"Should have received at least %d events, got %d", expectedTypes, received)
// Verify that we have the expected number of unique event types
assert.Equal(t, numEventTypes, expectedTypes,
"Should have exactly %d unique event types", numEventTypes)
}
func TestConcurrentHandlerRegistration(t *testing.T) {
const numGoroutines = 100
// Test concurrent subscriptions to the same event type
t.Run("SameEventType", func(t *testing.T) {
d := NewDispatcher()
var handlerCount int64
var wg sync.WaitGroup
// Start multiple goroutines subscribing to the same event type (0x1)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
atomic.AddInt64(&handlerCount, 1)
})
}()
}
wg.Wait()
// Verify all handlers were registered by publishing an event
atomic.StoreInt64(&handlerCount, 0)
Publish(d, MyEvent1{})
// Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond)
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
"Not all handlers were registered due to race condition")
})
// Test concurrent subscriptions to different event types
t.Run("DifferentEventTypes", func(t *testing.T) {
d := NewDispatcher()
var wg sync.WaitGroup
receivedEvents := make(map[uint32]*int64)
// Create multiple event types and subscribe concurrently
for i := 0; i < numGoroutines; i++ {
eventType := uint32(100 + i)
counter := new(int64)
receivedEvents[eventType] = counter
wg.Add(1)
go func(et uint32, cnt *int64) {
defer wg.Done()
SubscribeTo(d, et, func(ev MyEvent3) {
atomic.AddInt64(cnt, 1)
})
}(eventType, counter)
}
wg.Wait()
// Publish events to all types
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
Publish(d, MyEvent3{ID: int(eventType)})
}
// Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond)
// Verify all event types received their events
for eventType, counter := range receivedEvents {
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
"Event type %d did not receive its event", eventType)
}
})
}
func TestBackpressure(t *testing.T) {
d := NewDispatcher()
d.maxQueue = 10
var processedCount int64
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
atomic.AddInt64(&processedCount, 1)
})
defer unsub()
const eventsToPublish = 1000
for i := 0; i < eventsToPublish; i++ {
Publish(d, MyEvent3{ID: 0x200})
}
time.Sleep(100 * time.Millisecond)
// Verify all events were eventually processed
finalProcessed := atomic.LoadInt64(&processedCount)
assert.Equal(t, int64(eventsToPublish), finalProcessed)
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
}
// ------------------------------------- Test Events -------------------------------------
const (
TypeEvent1 = 0x1
TypeEvent2 = 0x2
)
type MyEvent1 struct {
Number int
}
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
type MyEvent2 struct {
Text string
}
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
type MyEvent3 struct {
ID int
}
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPublish(t *testing.T) {
d := NewDispatcher()
var wg sync.WaitGroup
// Subscribe, must be received in order
var count int64
defer Subscribe(d, func(ev MyEvent1) {
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
wg.Done()
})()
// Publish
wg.Add(3)
Publish(d, MyEvent1{Number: 1})
Publish(d, MyEvent1{Number: 2})
Publish(d, MyEvent1{Number: 3})
// Wait and check
wg.Wait()
assert.Equal(t, int64(3), count)
}
func TestUnsubscribe(t *testing.T) {
d := NewDispatcher()
assert.Equal(t, 0, d.count(TypeEvent1))
unsubscribe := Subscribe(d, func(ev MyEvent1) {
// Nothing
})
assert.Equal(t, 1, d.count(TypeEvent1))
unsubscribe()
assert.Equal(t, 0, d.count(TypeEvent1))
}
func TestConcurrent(t *testing.T) {
const max = 1000000
var count int64
var wg sync.WaitGroup
wg.Add(1)
d := NewDispatcher()
defer Subscribe(d, func(ev MyEvent1) {
if current := atomic.AddInt64(&count, 1); current == max {
wg.Done()
}
})()
// Asynchronously publish
go func() {
for i := 0; i < max; i++ {
Publish(d, MyEvent1{})
}
}()
defer Subscribe(d, func(ev MyEvent1) {
// Subscriber that does nothing
})()
wg.Wait()
assert.Equal(t, max, int(count))
}
func TestSubscribeDifferentType(t *testing.T) {
d := NewDispatcher()
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
})
}
func TestPublishDifferentType(t *testing.T) {
d := NewDispatcher()
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
Publish(d, MyEvent1{})
})
}
func TestCloseDispatcher(t *testing.T) {
d := NewDispatcher()
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
assert.NoError(t, d.Close())
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
})
}
func TestMatrix(t *testing.T) {
const amount = 1000
for _, subs := range []int{1, 10, 100} {
for _, topics := range []int{1, 10} {
expected := subs * topics * amount
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
var count atomic.Int64
var wg sync.WaitGroup
wg.Add(expected)
d := NewDispatcher()
for i := 0; i < subs; i++ {
for id := 0; id < topics; id++ {
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
count.Add(1)
wg.Done()
})()
}
}
for n := 0; n < amount; n++ {
for id := 0; id < topics; id++ {
go Publish(d, MyEvent3{ID: id})
}
}
wg.Wait()
assert.Equal(t, expected, int(count.Load()))
})
}
}
}
func TestConcurrentSubscriptionRace(t *testing.T) {
// This test specifically targets the race condition that occurs when multiple
// goroutines try to subscribe to different event types simultaneously.
// Without the CAS loop, subscriptions could be lost due to registry corruption.
const numGoroutines = 100
const numEventTypes = 50
d := NewDispatcher()
defer d.Close()
var wg sync.WaitGroup
var receivedCount int64
var subscribedTypes sync.Map // Thread-safe map
wg.Add(numGoroutines)
// Start multiple goroutines that subscribe to different event types concurrently
for i := 0; i < numGoroutines; i++ {
go func(goroutineID int) {
defer wg.Done()
// Each goroutine subscribes to a unique event type
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
// Subscribe to the event type
SubscribeTo(d, eventType, func(ev MyEvent3) {
atomic.AddInt64(&receivedCount, 1)
})
// Record that this type was subscribed
subscribedTypes.Store(eventType, true)
}(i)
}
// Wait for all subscriptions to complete
wg.Wait()
// Count the number of unique event types subscribed
expectedTypes := 0
subscribedTypes.Range(func(key, value interface{}) bool {
expectedTypes++
return true
})
// Small delay to ensure all subscriptions are fully processed
time.Sleep(10 * time.Millisecond)
// Publish events to each subscribed type
subscribedTypes.Range(func(key, value interface{}) bool {
eventType := key.(uint32)
Publish(d, MyEvent3{ID: int(eventType)})
return true
})
// Wait for all events to be processed
time.Sleep(50 * time.Millisecond)
// Verify that we received at least the expected number of events
// (there might be more if multiple goroutines subscribed to the same event type)
received := atomic.LoadInt64(&receivedCount)
assert.GreaterOrEqual(t, int(received), expectedTypes,
"Should have received at least %d events, got %d", expectedTypes, received)
// Verify that we have the expected number of unique event types
assert.Equal(t, numEventTypes, expectedTypes,
"Should have exactly %d unique event types", numEventTypes)
}
func TestConcurrentHandlerRegistration(t *testing.T) {
const numGoroutines = 100
// Test concurrent subscriptions to the same event type
t.Run("SameEventType", func(t *testing.T) {
d := NewDispatcher()
var handlerCount int64
var wg sync.WaitGroup
// Start multiple goroutines subscribing to the same event type (0x1)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
atomic.AddInt64(&handlerCount, 1)
})
}()
}
wg.Wait()
// Verify all handlers were registered by publishing an event
atomic.StoreInt64(&handlerCount, 0)
Publish(d, MyEvent1{})
// Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond)
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
"Not all handlers were registered due to race condition")
})
// Test concurrent subscriptions to different event types
t.Run("DifferentEventTypes", func(t *testing.T) {
d := NewDispatcher()
var wg sync.WaitGroup
receivedEvents := make(map[uint32]*int64)
// Create multiple event types and subscribe concurrently
for i := 0; i < numGoroutines; i++ {
eventType := uint32(100 + i)
counter := new(int64)
receivedEvents[eventType] = counter
wg.Add(1)
go func(et uint32, cnt *int64) {
defer wg.Done()
SubscribeTo(d, et, func(ev MyEvent3) {
atomic.AddInt64(cnt, 1)
})
}(eventType, counter)
}
wg.Wait()
// Publish events to all types
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
Publish(d, MyEvent3{ID: int(eventType)})
}
// Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond)
// Verify all event types received their events
for eventType, counter := range receivedEvents {
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
"Event type %d did not receive its event", eventType)
}
})
}
func TestBackpressure(t *testing.T) {
d := NewDispatcher()
d.maxQueue = 10
var processedCount int64
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
atomic.AddInt64(&processedCount, 1)
})
defer unsub()
const eventsToPublish = 1000
for i := 0; i < eventsToPublish; i++ {
Publish(d, MyEvent3{ID: 0x200})
}
time.Sleep(100 * time.Millisecond)
// Verify all events were eventually processed
finalProcessed := atomic.LoadInt64(&processedCount)
assert.Equal(t, int64(eventsToPublish), finalProcessed)
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
}
// ------------------------------------- Test Events -------------------------------------
const (
TypeEvent1 = 0x1
TypeEvent2 = 0x2
)
type MyEvent1 struct {
Number int
}
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
type MyEvent2 struct {
Text string
}
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
type MyEvent3 struct {
ID int
}
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
@@ -1,4 +1,4 @@
package proxy
package logmon
import (
"context"
@@ -8,15 +8,25 @@ import (
"sync"
"time"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/internal/event"
)
const DataEventID = 0x04
type DataEvent struct {
Data []byte
}
func (e DataEvent) Type() uint32 {
return DataEventID
}
// circularBuffer is a fixed-size circular byte buffer that overwrites
// oldest data when full. It provides O(1) writes and O(n) reads.
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
data []byte
head int
size int
}
func newCircularBuffer(capacity int) *circularBuffer {
@@ -27,8 +37,6 @@ func newCircularBuffer(capacity int) *circularBuffer {
}
}
// Write appends bytes to the buffer, overwriting oldest data when full.
// Data is copied into the internal buffer (not stored by reference).
func (cb *circularBuffer) Write(p []byte) {
if len(p) == 0 {
return
@@ -36,7 +44,6 @@ func (cb *circularBuffer) Write(p []byte) {
cap := len(cb.data)
// If input is larger than capacity, only keep the last cap bytes
if len(p) >= cap {
copy(cb.data, p[len(p)-cap:])
cb.head = 0
@@ -44,28 +51,22 @@ func (cb *circularBuffer) Write(p []byte) {
return
}
// Calculate how much space is available from head to end of buffer
firstPart := cap - cb.head
if firstPart >= len(p) {
// All data fits without wrapping
copy(cb.data[cb.head:], p)
cb.head = (cb.head + len(p)) % cap
} else {
// Data wraps around
copy(cb.data[cb.head:], p[:firstPart])
copy(cb.data[:len(p)-firstPart], p[firstPart:])
cb.head = len(p) - firstPart
}
// Update size
cb.size += len(p)
if cb.size > cap {
cb.size = cap
}
}
// GetHistory returns all buffered data in correct order (oldest to newest).
// Returns a new slice (copy), not a view into internal buffer.
func (cb *circularBuffer) GetHistory() []byte {
if cb.size == 0 {
return nil
@@ -74,14 +75,11 @@ func (cb *circularBuffer) GetHistory() []byte {
result := make([]byte, cb.size)
cap := len(cb.data)
// Calculate start position (oldest data)
start := (cb.head - cb.size + cap) % cap
if start+cb.size <= cap {
// Data is contiguous, single copy
copy(result, cb.data[start:start+cb.size])
} else {
// Data wraps around, two copies
firstPart := cap - start
copy(result[:firstPart], cb.data[start:])
copy(result[firstPart:], cb.data[:cb.size-firstPart])
@@ -90,42 +88,38 @@ func (cb *circularBuffer) GetHistory() []byte {
return result
}
type LogLevel int
type Level int
const (
LevelDebug LogLevel = iota
LevelDebug Level = iota
LevelInfo
LevelWarn
LevelError
LogBufferSize = 100 * 1024
BufferSize = 100 * 1024
)
type LogMonitor struct {
type Monitor struct {
eventbus *event.Dispatcher
mu sync.RWMutex
buffer *circularBuffer
bufferMu sync.RWMutex
// typically this can be os.Stdout
stdout io.Writer
// logging levels
level LogLevel
prefix string
// timestamps
level Level
prefix string
timeFormat string
}
func NewLogMonitor() *LogMonitor {
return NewLogMonitorWriter(os.Stdout)
func New() *Monitor {
return NewWriter(os.Stdout)
}
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{
func NewWriter(stdout io.Writer) *Monitor {
return &Monitor{
eventbus: event.NewDispatcherConfig(1000),
buffer: nil, // lazy initialized on first Write
buffer: nil,
stdout: stdout,
level: LevelInfo,
prefix: "",
@@ -133,7 +127,7 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
}
}
func (w *LogMonitor) Write(p []byte) (n int, err error) {
func (w *Monitor) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
@@ -145,19 +139,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
w.bufferMu.Lock()
if w.buffer == nil {
w.buffer = newCircularBuffer(LogBufferSize)
w.buffer = newCircularBuffer(BufferSize)
}
w.buffer.Write(p)
w.bufferMu.Unlock()
// Make a copy for broadcast to preserve immutability
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.broadcast(bufferCopy)
return n, nil
}
func (w *LogMonitor) GetHistory() []byte {
func (w *Monitor) GetHistory() []byte {
w.bufferMu.RLock()
defer w.bufferMu.RUnlock()
if w.buffer == nil {
@@ -168,41 +161,41 @@ func (w *LogMonitor) GetHistory() []byte {
// Clear releases the buffer memory, making it eligible for GC.
// The buffer will be lazily re-allocated on the next Write.
func (w *LogMonitor) Clear() {
func (w *Monitor) Clear() {
w.bufferMu.Lock()
w.buffer = nil
w.bufferMu.Unlock()
}
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
func (w *Monitor) OnLogData(callback func(data []byte)) context.CancelFunc {
return event.Subscribe(w.eventbus, func(e DataEvent) {
callback(e.Data)
})
}
func (w *LogMonitor) broadcast(msg []byte) {
event.Publish(w.eventbus, LogDataEvent{Data: msg})
func (w *Monitor) broadcast(msg []byte) {
event.Publish(w.eventbus, DataEvent{Data: msg})
}
func (w *LogMonitor) SetPrefix(prefix string) {
func (w *Monitor) SetPrefix(prefix string) {
w.mu.Lock()
defer w.mu.Unlock()
w.prefix = prefix
}
func (w *LogMonitor) SetLogLevel(level LogLevel) {
func (w *Monitor) SetLogLevel(level Level) {
w.mu.Lock()
defer w.mu.Unlock()
w.level = level
}
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
func (w *Monitor) SetLogTimeFormat(timeFormat string) {
w.mu.Lock()
defer w.mu.Unlock()
w.timeFormat = timeFormat
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
func (w *Monitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
@@ -211,49 +204,38 @@ func (w *LogMonitor) formatMessage(level string, msg string) []byte {
if w.timeFormat != "" {
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
}
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
return fmt.Appendf(nil, "%s%s[%s] %s\n", timestamp, prefix, level, msg)
}
func (w *LogMonitor) log(level LogLevel, msg string) {
func (w *Monitor) log(level Level, msg string) {
if level < w.level {
return
}
w.Write(w.formatMessage(level.String(), msg))
}
func (w *LogMonitor) Debug(msg string) {
w.log(LevelDebug, msg)
}
func (w *Monitor) Debug(msg string) { w.log(LevelDebug, msg) }
func (w *Monitor) Info(msg string) { w.log(LevelInfo, msg) }
func (w *Monitor) Warn(msg string) { w.log(LevelWarn, msg) }
func (w *Monitor) Error(msg string) { w.log(LevelError, msg) }
func (w *LogMonitor) Info(msg string) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
func (w *Monitor) Debugf(format string, args ...any) {
w.log(LevelDebug, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Infof(format string, args ...interface{}) {
func (w *Monitor) Infof(format string, args ...any) {
w.log(LevelInfo, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
func (w *Monitor) Warnf(format string, args ...any) {
w.log(LevelWarn, fmt.Sprintf(format, args...))
}
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
func (w *Monitor) Errorf(format string, args ...any) {
w.log(LevelError, fmt.Sprintf(format, args...))
}
func (l LogLevel) String() string {
func (l Level) String() string {
switch l {
case LevelDebug:
return "DEBUG"
@@ -1,4 +1,4 @@
package proxy
package logmon
import (
"bytes"
@@ -10,9 +10,8 @@ import (
)
func TestLogMonitor(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
logMonitor := NewWriter(io.Discard)
// A WaitGroup is used to wait for all the expected writes to complete
var wg sync.WaitGroup
client1Messages := make([]byte, 0)
@@ -34,10 +33,8 @@ func TestLogMonitor(t *testing.T) {
logMonitor.Write([]byte("2"))
logMonitor.Write([]byte("3"))
// wait for all writes to complete
wg.Wait()
// Check the buffer
expectedHistory := "123"
history := string(logMonitor.GetHistory())
@@ -57,14 +54,11 @@ func TestLogMonitor(t *testing.T) {
}
func TestWrite_ImmutableBuffer(t *testing.T) {
// Create a new LogMonitor instance
lm := NewLogMonitorWriter(io.Discard)
lm := NewWriter(io.Discard)
// Prepare a message to write
msg := []byte("Hello, World!")
lenmsg := len(msg)
// Write the message to the LogMonitor
n, err := lm.Write(msg)
if err != nil {
t.Fatalf("Write failed: %v", err)
@@ -74,13 +68,10 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
t.Errorf("Expected %d bytes written but got %d", lenmsg, n)
}
// Change the original message
msg[0] = 'B' // This should not affect the buffer
msg[0] = 'B'
// Get the history from the LogMonitor
history := lm.GetHistory()
// Check that the history contains the original message, not the modified one
expected := []byte("Hello, World!")
if !bytes.Equal(history, expected) {
t.Errorf("Expected history to be %q, got %q", expected, history)
@@ -88,16 +79,12 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
}
func TestWrite_LogTimeFormat(t *testing.T) {
// Create a new LogMonitor instance
lm := NewLogMonitorWriter(io.Discard)
lm := NewWriter(io.Discard)
// Enable timestamps
lm.timeFormat = time.RFC3339
// Write the message to the LogMonitor
lm.Info("Hello, World!")
// Get the history from the LogMonitor
history := lm.GetHistory()
timestamp := ""
@@ -115,48 +102,40 @@ func TestWrite_LogTimeFormat(t *testing.T) {
}
func TestCircularBuffer_WrapAround(t *testing.T) {
// Create a small buffer to test wrap-around
cb := newCircularBuffer(10)
// Write "hello" (5 bytes)
cb.Write([]byte("hello"))
if got := string(cb.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Write "world" (5 bytes) - buffer now full
cb.Write([]byte("world"))
if got := string(cb.GetHistory()); got != "helloworld" {
t.Errorf("Expected 'helloworld', got %q", got)
}
// Write "12345" (5 bytes) - should overwrite "hello"
cb.Write([]byte("12345"))
if got := string(cb.GetHistory()); got != "world12345" {
t.Errorf("Expected 'world12345', got %q", got)
}
// Write data larger than buffer capacity
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
cb.Write([]byte("abcdefghijklmnop"))
if got := string(cb.GetHistory()); got != "ghijklmnop" {
t.Errorf("Expected 'ghijklmnop', got %q", got)
}
}
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
// Test empty buffer
cb := newCircularBuffer(10)
if got := cb.GetHistory(); got != nil {
t.Errorf("Expected nil for empty buffer, got %q", got)
}
// Test exact capacity
cb.Write([]byte("1234567890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
// Test write exactly at capacity boundary
cb = newCircularBuffer(10)
cb.Write([]byte("12345"))
cb.Write([]byte("67890"))
@@ -166,19 +145,16 @@ func TestCircularBuffer_BoundaryConditions(t *testing.T) {
}
func TestLogMonitor_LazyInit(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
lm := NewWriter(io.Discard)
// Buffer should be nil before any writes
if lm.buffer != nil {
t.Error("Expected buffer to be nil before first write")
}
// GetHistory should return nil when buffer is nil
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history before first write, got %q", got)
}
// Write should lazily initialize the buffer
lm.Write([]byte("test"))
if lm.buffer == nil {
@@ -191,15 +167,13 @@ func TestLogMonitor_LazyInit(t *testing.T) {
}
func TestLogMonitor_Clear(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
lm := NewWriter(io.Discard)
// Write some data
lm.Write([]byte("hello"))
if got := string(lm.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Clear should release the buffer
lm.Clear()
if lm.buffer != nil {
@@ -212,9 +186,8 @@ func TestLogMonitor_Clear(t *testing.T) {
}
func TestLogMonitor_ClearAndReuse(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
lm := NewWriter(io.Discard)
// Write, clear, then write again
lm.Write([]byte("first"))
lm.Clear()
lm.Write([]byte("second"))
@@ -225,13 +198,12 @@ func TestLogMonitor_ClearAndReuse(t *testing.T) {
}
func BenchmarkLogMonitorWrite(b *testing.B) {
// Test data of varying sizes
smallMsg := []byte("small message\n")
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
b.Run("SmallWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
lm := NewWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(smallMsg)
@@ -239,7 +211,7 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
})
b.Run("MediumWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
lm := NewWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
@@ -247,7 +219,7 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
})
b.Run("LargeWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
lm := NewWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(largeMsg)
@@ -255,8 +227,7 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
})
b.Run("WithSubscribers", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Add some subscribers
lm := NewWriter(io.Discard)
for i := 0; i < 5; i++ {
lm.OnLogData(func(data []byte) {})
}
@@ -267,8 +238,7 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
})
b.Run("GetHistory", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Pre-populate with data
lm := NewWriter(io.Discard)
for i := 0; i < 1000; i++ {
lm.Write(mediumMsg)
}
@@ -278,39 +248,3 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
}
})
}
/*
Benchmark Results - MBP M1 Pro
Before (ring.Ring):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 43 ns | 40 B | 2 |
| MediumWrite (241B) | 76 ns | 264 B | 2 |
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
After (circularBuffer 10KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 67 ns | 240 B | 1 |
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
After (circularBuffer 100KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|-----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 66 ns | 240 B | 1 |
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
Summary:
- GetHistory: 139x faster (10KB), 18x faster (100KB)
- Allocations: reduced from 2 to 1 across all operations
- Small/medium writes: ~1.1-1.6x faster
*/
+214
View File
@@ -0,0 +1,214 @@
package perf
import (
"encoding/json"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"time"
)
// ParseNvidiaSmiLine parses a single line from nvidia-smi CSV output.
// Format: index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw
func ParseNvidiaSmiLine(line string) *GpuStat {
fields := strings.Split(line, ",")
if len(fields) < 9 {
return nil
}
id, _ := strconv.Atoi(strings.TrimSpace(fields[0]))
name := strings.TrimSpace(fields[1])
uuid := strings.TrimSpace(fields[2])
tempC, _ := strconv.Atoi(strings.TrimSpace(fields[3]))
gpuUtil, _ := strconv.ParseFloat(strings.TrimSpace(fields[4]), 64)
memUsed, _ := strconv.Atoi(strings.TrimSpace(fields[5]))
memTotal, _ := strconv.Atoi(strings.TrimSpace(fields[6]))
fanSpeed, _ := strconv.ParseFloat(strings.TrimSpace(fields[7]), 64)
powerDraw, _ := strconv.ParseFloat(strings.TrimSpace(fields[8]), 64)
var memUtil float64
if memTotal > 0 {
memUtil = float64(memUsed) / float64(memTotal) * 100
}
return &GpuStat{
Timestamp: time.Now(),
ID: id,
Name: name,
UUID: uuid,
TempC: tempC,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtil,
MemUsedMB: memUsed,
MemTotalMB: memTotal,
FanSpeedPct: fanSpeed,
PowerDrawW: powerDraw,
}
}
// mactopOutput maps the subset of mactop's headless JSON output that is
// relevant to GpuStat. Note that mactop's memory object is whole-system memory,
// not GPU-attributed; the darwin monitor overlays ioreg's GPU-attributed
// unified memory (see overlayIoregMem) so both backends report consistent
// memory figures.
type mactopOutput struct {
SocMetrics struct {
GPUPower float64 `json:"gpu_power"`
GPUFreq int `json:"gpu_freq_mhz"`
GPUTemp float64 `json:"gpu_temp"`
} `json:"soc_metrics"`
Memory struct {
Total uint64 `json:"total"`
Used uint64 `json:"used"`
} `json:"memory"`
GPUUsage float64 `json:"gpu_usage"`
SystemInfo struct {
Name string `json:"name"`
GPUCoreCount int `json:"gpu_core_count"`
} `json:"system_info"`
Fans []struct {
RPM int `json:"rpm"`
MinRPM int `json:"min_rpm"`
MaxRPM int `json:"max_rpm"`
} `json:"fans"`
Temperatures []struct {
Group string `json:"group"`
Avg float64 `json:"avg_celsius"`
} `json:"temperatures"`
}
// ioreg output uses ` = ` (with spaces) for top-level device properties and
// `=` (no spaces) for values inside nested dictionaries such as
// PerformanceStatistics.
var (
reIoregModel = regexp.MustCompile(`"model"\s*=\s*"([^"]+)"`)
reIoregCoreCount = regexp.MustCompile(`"gpu-core-count"\s*=\s*(\d+)`)
reIoregUtil = regexp.MustCompile(`"Device Utilization %"=(\d+)`)
reIoregMemUsed = regexp.MustCompile(`"In use system memory"=(\d+)`)
)
// ParseIoregOutput parses `ioreg -r -c IOGPU -d 1 -f` output into a GpuStat for
// the Apple Silicon integrated GPU. This is a fallback for when mactop is not
// installed: utilization and used memory are available, but power, temperature,
// and fan speed are not exposed by ioreg. memTotalMB is the unified memory size
// supplied by the caller, since Apple Silicon shares memory between CPU and GPU.
// Returns nil if no GPU device is found in the output.
func ParseIoregOutput(out []byte, memTotalMB int) *GpuStat {
utilMatch := reIoregUtil.FindSubmatch(out)
memMatch := reIoregMemUsed.FindSubmatch(out)
if utilMatch == nil && memMatch == nil {
return nil
}
var gpuUtil float64
if utilMatch != nil {
gpuUtil, _ = strconv.ParseFloat(string(utilMatch[1]), 64)
}
const toMB = 1024 * 1024
var memUsedMB int
if memMatch != nil {
memUsedBytes, _ := strconv.ParseInt(string(memMatch[1]), 10, 64)
memUsedMB = int(memUsedBytes / toMB)
}
var memUtil float64
if memTotalMB > 0 {
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
}
name := "Apple GPU"
if m := reIoregModel.FindSubmatch(out); m != nil {
name = string(m[1])
}
if m := reIoregCoreCount.FindSubmatch(out); m != nil {
if cores, err := strconv.Atoi(string(m[1])); err == nil && cores > 0 {
name = fmt.Sprintf("%s (%d-core GPU)", name, cores)
}
}
return &GpuStat{
Timestamp: time.Now(),
ID: 0,
Name: name,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtil,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
}
}
// ParseMactopLine parses a single line of mactop headless JSON output into a
// GpuStat for the Apple Silicon integrated GPU. Returns nil if the line cannot
// be parsed.
func ParseMactopLine(line string) *GpuStat {
line = strings.TrimSpace(line)
if line == "" {
return nil
}
var out mactopOutput
if err := json.Unmarshal([]byte(line), &out); err != nil {
return nil
}
const toMB = 1024 * 1024
memUsedMB := int(out.Memory.Used / toMB)
memTotalMB := int(out.Memory.Total / toMB)
var memUtil float64
if memTotalMB > 0 {
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
}
name := out.SystemInfo.Name
if name == "" {
name = "Apple GPU"
}
if out.SystemInfo.GPUCoreCount > 0 {
name = fmt.Sprintf("%s (%d-core GPU)", name, out.SystemInfo.GPUCoreCount)
}
// Unified memory has no dedicated VRAM sensor; use the memory temperature
// group when mactop exposes it.
var vramTempC int
for _, t := range out.Temperatures {
if strings.EqualFold(t.Group, "Memory") {
vramTempC = int(math.Round(t.Avg))
break
}
}
// Average fan load across all fans as a percentage of their RPM range.
var fanSpeed float64
var fanCount int
for _, f := range out.Fans {
if f.MaxRPM > f.MinRPM {
pct := float64(f.RPM-f.MinRPM) / float64(f.MaxRPM-f.MinRPM) * 100
if pct < 0 {
pct = 0
}
fanSpeed += pct
fanCount++
}
}
if fanCount > 0 {
fanSpeed /= float64(fanCount)
}
return &GpuStat{
Timestamp: time.Now(),
ID: 0,
Name: name,
TempC: int(math.Round(out.SocMetrics.GPUTemp)),
VramTempC: vramTempC,
GpuUtilPct: out.GPUUsage,
MemUtilPct: memUtil,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
FanSpeedPct: fanSpeed,
PowerDrawW: out.SocMetrics.GPUPower,
}
}
+206
View File
@@ -0,0 +1,206 @@
package perf
import (
"context"
"errors"
"sync"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/ring"
)
var (
ErrNotImplemented = errors.New("Not Implemented")
ErrNoGpuTool = errors.New("no GPU monitoring tool available")
)
type Monitor struct {
mutex sync.RWMutex
log *logmon.Monitor
conf config.PerformanceConfig
sysRing ring.Buffer[SysStat]
gpuRing ring.Buffer[[]GpuStat]
stopCtx context.Context
stopCancel context.CancelFunc
sysListeners map[chan SysStat]struct{}
gpuListeners map[chan []GpuStat]struct{}
}
func ringCapacity(c config.PerformanceConfig) int {
n := int(time.Hour / c.Every)
if n < 1 {
n = 1
}
return n
}
func New(c config.PerformanceConfig, logger *logmon.Monitor) (*Monitor, error) {
if c.Every < 100*time.Millisecond {
c.Every = 100 * time.Millisecond
}
if logger == nil {
return nil, errors.New("logger is required")
}
capacity := ringCapacity(c)
return &Monitor{
conf: c,
log: logger,
sysRing: ring.NewBuffer[SysStat](capacity),
gpuRing: ring.NewBuffer[[]GpuStat](capacity),
sysListeners: make(map[chan SysStat]struct{}),
gpuListeners: make(map[chan []GpuStat]struct{}),
}, nil
}
func (m *Monitor) Stop() {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.stopCancel == nil {
return
}
m.stopCancel()
m.stopCancel = nil
}
// UpdateConfig updates the monitor configuration and restarts if changed.
func (m *Monitor) UpdateConfig(newConf config.PerformanceConfig) {
m.mutex.RLock()
changed := m.conf != newConf
m.mutex.RUnlock()
if !changed {
return
}
m.Stop()
m.mutex.Lock()
m.conf = newConf
capacity := ringCapacity(newConf)
m.sysRing = ring.NewBuffer[SysStat](capacity)
m.gpuRing = ring.NewBuffer[[]GpuStat](capacity)
m.mutex.Unlock()
if !newConf.Disabled {
m.Start()
}
}
// Subscribe returns channels to listen to system and GPU stats.
func (m *Monitor) Subscribe() (chan SysStat, chan []GpuStat, func()) {
m.mutex.Lock()
defer m.mutex.Unlock()
sysChan := make(chan SysStat, 1)
gpuChan := make(chan []GpuStat, 1)
m.sysListeners[sysChan] = struct{}{}
m.gpuListeners[gpuChan] = struct{}{}
unsub := func() {
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.sysListeners, sysChan)
delete(m.gpuListeners, gpuChan)
}
return sysChan, gpuChan, unsub
}
func (m *Monitor) Start() {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.stopCancel != nil {
return
}
m.stopCtx, m.stopCancel = context.WithCancel(context.Background())
go func() {
tick := time.NewTicker(m.conf.Every)
defer tick.Stop()
for {
select {
case <-m.stopCtx.Done():
return
case <-tick.C:
s, err := ReadSysStats()
if err != nil {
if err != ErrNotImplemented {
m.log.Errorf("failed to read sys stats: %s", err.Error())
}
continue
}
m.mutex.Lock()
m.sysRing.Push(s)
for l := range m.sysListeners {
select {
case l <- s:
default:
}
}
m.mutex.Unlock()
}
}
}()
go func() {
gpuCh, err := getGpuStats(m.stopCtx, m.conf.Every, m.log)
if err != nil {
if errors.Is(err, ErrNotImplemented) || errors.Is(err, ErrNoGpuTool) {
m.log.Infof("GPU monitoring not available: %s", err.Error())
} else {
m.log.Errorf("failed to initialize GPU monitoring: %s", err.Error())
}
return
}
for {
select {
case <-m.stopCtx.Done():
return
case g, ok := <-gpuCh:
if !ok {
m.log.Errorf("failed reading from gpuCh - stopping read goroutine")
return
}
m.mutex.Lock()
m.gpuRing.Push(g)
for l := range m.gpuListeners {
select {
case l <- g:
default:
}
}
m.mutex.Unlock()
}
}
}()
}
// Current returns a copy of the current log of system and GPU stats.
func (m *Monitor) Current() ([]SysStat, []GpuStat) {
m.mutex.RLock()
defer m.mutex.RUnlock()
sysStats := m.sysRing.Slice()
snapshots := m.gpuRing.Slice()
var gpuStats []GpuStat
for _, snapshot := range snapshots {
gpuStats = append(gpuStats, snapshot...)
}
return sysStats, gpuStats
}
func ReadSysStats() (SysStat, error) {
return readSysStats()
}
func GetGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
return getGpuStats(ctx, every, logger)
}
+208
View File
@@ -0,0 +1,208 @@
package perf
import (
"bufio"
"context"
"fmt"
"os/exec"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/load"
"github.com/shirou/gopsutil/v4/mem"
)
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if ch, err := tryMactop(ctx, every, logger); err == nil {
logger.Info("using mactop for GPU monitoring")
return ch, nil
} else {
logger.Debugf("mactop: %s", err.Error())
}
if ch, err := tryIoreg(ctx, every, logger); err == nil {
logger.Info("using ioreg for GPU monitoring")
return ch, nil
} else {
logger.Debugf("ioreg: %s", err.Error())
}
return nil, ErrNoGpuTool
}
// tryIoreg polls `ioreg -r -c IOGPU -d 1 -f` for Apple Silicon GPU stats. It is
// a fallback for when mactop is not installed. ioreg exposes GPU utilization and
// used memory but not power, temperature, or fan speed.
func tryIoreg(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("ioreg"); err != nil {
return nil, ErrNoGpuTool
}
// Verify ioreg actually reports a GPU device before committing to it, so we
// can fall through to ErrNoGpuTool otherwise.
if stat := sampleIoreg(ctx); stat == nil {
return nil, fmt.Errorf("ioreg reported no GPU device")
}
if every < time.Second {
every = time.Second
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
stat := sampleIoreg(ctx)
if stat == nil {
continue
}
select {
case ch <- []GpuStat{*stat}:
default:
}
}
}
}()
return ch, nil
}
// sampleIoreg runs ioreg once and parses a single GpuStat, or returns nil.
func sampleIoreg(ctx context.Context) *GpuStat {
out, err := exec.CommandContext(ctx, "ioreg", "-r", "-c", "IOGPU", "-d", "1", "-f").Output()
if err != nil {
return nil
}
var memTotalMB int
if vmStat, err := mem.VirtualMemory(); err == nil {
memTotalMB = int(vmStat.Total / (1024 * 1024))
}
return ParseIoregOutput(out, memTotalMB)
}
// overlayIoregMem replaces a GpuStat's memory fields with the GPU-attributed
// unified memory reported by ioreg. mactop only exposes whole-system memory, so
// without this the mactop and ioreg backends would report different memory
// semantics. It is a no-op when ioreg is unavailable or reports no GPU memory,
// leaving the mactop-supplied values in place.
func overlayIoregMem(ctx context.Context, stat *GpuStat) {
ioStat := sampleIoreg(ctx)
if ioStat == nil {
return
}
stat.MemUsedMB = ioStat.MemUsedMB
stat.MemTotalMB = ioStat.MemTotalMB
stat.MemUtilPct = ioStat.MemUtilPct
}
// tryMactop streams Apple Silicon GPU stats from mactop's headless mode.
// See https://github.com/metaspartan/mactop. mactop emits one JSON object per
// sample to stdout, which we parse into GpuStat.
func tryMactop(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("mactop"); err != nil {
return nil, ErrNoGpuTool
}
// mactop samples power over the interval, so give it at least a second.
intervalMs := int(every.Milliseconds())
if intervalMs < 1000 {
intervalMs = 1000
}
cmd := exec.CommandContext(ctx, "mactop",
"--headless",
"--format", "json",
"--interval", fmt.Sprintf("%d", intervalMs),
)
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("mactop stdout pipe failed: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("mactop start failed: %w", err)
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
scanner := bufio.NewScanner(stdout)
// mactop's JSON objects can be large; allow generous line lengths.
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
stat := ParseMactopLine(line)
if stat != nil {
// mactop only reports whole-system memory; overlay ioreg's
// GPU-attributed unified memory so both backends are consistent.
overlayIoregMem(ctx, stat)
select {
case ch <- []GpuStat{*stat}:
default:
}
}
}
cmd.Wait()
}()
return ch, nil
}
func readSysStats() (SysStat, error) {
cpuPcts, err := cpu.Percent(0, true)
if err != nil {
return SysStat{}, err
}
vmStat, err := mem.VirtualMemory()
if err != nil {
return SysStat{}, err
}
const toMB = 1024 * 1024
var swapTotalMB, swapUsedMB int
if swapStat, err := mem.SwapMemory(); err == nil {
swapTotalMB = int(swapStat.Total / toMB)
swapUsedMB = int(swapStat.Used / toMB)
}
var loadAvg1, loadAvg5, loadAvg15 float64
if loadStat, err := load.Avg(); err == nil {
loadAvg1 = loadStat.Load1
loadAvg5 = loadStat.Load5
loadAvg15 = loadStat.Load15
}
return SysStat{
Timestamp: time.Now(),
CpuUtilPerCore: cpuPcts,
MemTotalMB: int(vmStat.Total / toMB),
MemUsedMB: int(vmStat.Used / toMB),
MemFreeMB: int(vmStat.Free / toMB),
SwapTotalMB: swapTotalMB,
SwapUsedMB: swapUsedMB,
LoadAvg1: loadAvg1,
LoadAvg5: loadAvg5,
LoadAvg15: loadAvg15,
}, nil
}
+313
View File
@@ -0,0 +1,313 @@
package perf
import (
"io"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestLogger() *logmon.Monitor {
return logmon.NewWriter(io.Discard)
}
func TestNew_DefaultConfig(t *testing.T) {
logger := newTestLogger()
m, err := New(config.PerformanceConfig{}, logger)
require.NoError(t, err)
require.NotNil(t, m)
assert.Equal(t, 100*time.Millisecond, m.conf.Every)
}
func TestNew_CustomConfig(t *testing.T) {
logger := newTestLogger()
cfg := config.PerformanceConfig{
Every: 500 * time.Millisecond,
}
m, err := New(cfg, logger)
require.NoError(t, err)
assert.Equal(t, 500*time.Millisecond, m.conf.Every)
}
func TestNew_NilLogger(t *testing.T) {
m, err := New(config.PerformanceConfig{}, nil)
assert.Error(t, err)
assert.Nil(t, m)
}
func TestNew_BelowMinimumConfig(t *testing.T) {
logger := newTestLogger()
cfg := config.PerformanceConfig{
Every: 1 * time.Millisecond,
}
m, err := New(cfg, logger)
require.NoError(t, err)
assert.Equal(t, 100*time.Millisecond, m.conf.Every)
}
func TestSubscribe_ReturnsChannels(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
sysCh, gpuCh, unsub := m.Subscribe()
defer unsub()
assert.NotNil(t, sysCh)
assert.NotNil(t, gpuCh)
assert.NotNil(t, unsub)
}
func TestSubscribe_UnsubscribeRemovesListeners(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
_, _, unsub := m.Subscribe()
m.mutex.RLock()
assert.Len(t, m.sysListeners, 1)
assert.Len(t, m.gpuListeners, 1)
m.mutex.RUnlock()
unsub()
m.mutex.RLock()
assert.Len(t, m.sysListeners, 0)
assert.Len(t, m.gpuListeners, 0)
m.mutex.RUnlock()
}
func TestSubscribe_MultipleSubscriptions(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
sysCh1, gpuCh1, unsub1 := m.Subscribe()
sysCh2, gpuCh2, unsub2 := m.Subscribe()
defer unsub1()
defer unsub2()
assert.NotEqual(t, sysCh1, sysCh2)
assert.NotEqual(t, gpuCh1, gpuCh2)
m.mutex.RLock()
assert.Len(t, m.sysListeners, 2)
assert.Len(t, m.gpuListeners, 2)
m.mutex.RUnlock()
}
func TestCurrent_EmptyByDefault(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
sysStats, gpuStats := m.Current()
assert.Empty(t, sysStats)
assert.Empty(t, gpuStats)
}
func TestCurrent_ReturnsCopies(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
now := time.Now()
m.sysRing.Push(SysStat{Timestamp: now, MemTotalMB: 1024})
m.gpuRing.Push([]GpuStat{{Timestamp: now, ID: 0, Name: "gpu0"}})
sysStats, gpuStats := m.Current()
assert.Len(t, sysStats, 1)
assert.Len(t, gpuStats, 1)
assert.Equal(t, 1024, sysStats[0].MemTotalMB)
assert.Equal(t, "gpu0", gpuStats[0].Name)
// modifying the returned slice should not affect the original
sysStats[0].MemTotalMB = 999
original, _ := m.Current()
assert.Equal(t, 1024, original[0].MemTotalMB)
}
func TestStart_CollectsSysStats(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
m, err := New(config.PerformanceConfig{Every: 100 * time.Millisecond}, newTestLogger())
require.NoError(t, err)
m.Start()
time.Sleep(350 * time.Millisecond)
m.Stop()
sysStats, _ := m.Current()
assert.NotEmpty(t, sysStats, "expected sys stats to be collected")
}
func TestStart_StopStopsGoroutines(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
m, err := New(config.PerformanceConfig{Every: 100 * time.Millisecond}, newTestLogger())
require.NoError(t, err)
m.Start()
if m.stopCancel == nil {
t.Error("stopCancel should not be nil after Start()")
}
m.Stop()
if m.stopCancel != nil {
t.Error("stopCancel should be nil after Stop()")
}
}
func TestStart_SubscriberReceivesStats(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
m, err := New(config.PerformanceConfig{Every: 100 * time.Millisecond}, newTestLogger())
require.NoError(t, err)
sysCh, _, unsub := m.Subscribe()
defer unsub()
m.Start()
defer m.Stop()
select {
case s := <-sysCh:
assert.False(t, s.Timestamp.IsZero())
assert.NotEmpty(t, s.CpuUtilPerCore)
case <-time.After(500 * time.Millisecond):
t.Fatal("timed out waiting for sys stats")
}
}
func TestReadSysStats(t *testing.T) {
s, err := ReadSysStats()
require.NoError(t, err)
assert.False(t, s.Timestamp.IsZero())
assert.NotEmpty(t, s.CpuUtilPerCore)
assert.Greater(t, s.MemTotalMB, 0)
}
func TestCurrent_ConcurrentAccess(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
m.sysRing.Push(SysStat{Timestamp: time.Now(), MemTotalMB: 1024})
m.gpuRing.Push([]GpuStat{{Timestamp: time.Now(), ID: 0}})
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
sys, gpu := m.Current()
assert.Len(t, sys, 1)
assert.Len(t, gpu, 1)
}()
}
wg.Wait()
}
func TestParseNvidiaSmiLine_ValidLine(t *testing.T) {
line := "0, NVIDIA GeForce RTX 3080, GPU-12345678-1234-1234-1234-123456789abc, 65, 80, 8192, 10240, 75, 250"
stat := ParseNvidiaSmiLine(line)
require.NotNil(t, stat)
assert.Equal(t, 0, stat.ID)
assert.Equal(t, "NVIDIA GeForce RTX 3080", stat.Name)
assert.Equal(t, "GPU-12345678-1234-1234-1234-123456789abc", stat.UUID)
assert.Equal(t, 65, stat.TempC)
assert.Equal(t, 80.0, stat.GpuUtilPct)
assert.Equal(t, 8192, stat.MemUsedMB)
assert.Equal(t, 10240, stat.MemTotalMB)
assert.Equal(t, 75.0, stat.FanSpeedPct)
assert.Equal(t, 250.0, stat.PowerDrawW)
assert.InDelta(t, 80.0, stat.MemUtilPct, 0.01)
}
func TestParseNvidiaSmiLine_ShortLine(t *testing.T) {
line := "0, NVIDIA GPU, GPU-123"
stat := ParseNvidiaSmiLine(line)
assert.Nil(t, stat)
}
func TestParseNvidiaSmiLine_MissingFields(t *testing.T) {
line := "0, NVIDIA GPU, GPU-123, 65, 80, 8192, 10240, 75"
stat := ParseNvidiaSmiLine(line)
assert.Nil(t, stat)
}
func TestParseNvidiaSmiLine_ZeroMemoryTotal(t *testing.T) {
line := "0, NVIDIA GPU, GPU-123, 65, 80, 0, 0, 75, 250"
stat := ParseNvidiaSmiLine(line)
require.NotNil(t, stat)
assert.Equal(t, 0.0, stat.MemUtilPct)
}
const ioregSample = `+-o AGXAcceleratorG13X <class AGXAcceleratorG13X, id 0x1000009a1, registered, matched, active, busy 0 (39191 ms), retain 108>
{
"model" = "Apple M1 Pro"
"gpu-core-count" = 16
"PerformanceStatistics" = {"In use system memory (driver)"=0,"Alloc system memory"=14511046656,"Tiler Utilization %"=34,"recoveryCount"=0,"Renderer Utilization %"=34,"Device Utilization %"=34,"In use system memory"=7688503296}
"IOClass" = "AGXAcceleratorG13X"
}`
func TestParseIoregOutput_ValidOutput(t *testing.T) {
const memTotalMB = 32768
stat := ParseIoregOutput([]byte(ioregSample), memTotalMB)
require.NotNil(t, stat)
assert.Equal(t, 0, stat.ID)
assert.Equal(t, "Apple M1 Pro (16-core GPU)", stat.Name)
assert.Equal(t, 34.0, stat.GpuUtilPct)
assert.Equal(t, 7688503296/(1024*1024), stat.MemUsedMB)
assert.Equal(t, memTotalMB, stat.MemTotalMB)
assert.InDelta(t, float64(stat.MemUsedMB)/memTotalMB*100, stat.MemUtilPct, 0.01)
// Not exposed by ioreg.
assert.Equal(t, 0, stat.TempC)
assert.Equal(t, 0.0, stat.PowerDrawW)
assert.Equal(t, 0.0, stat.FanSpeedPct)
}
func TestParseIoregOutput_NoGpuDevice(t *testing.T) {
stat := ParseIoregOutput([]byte("no gpu here"), 32768)
assert.Nil(t, stat)
}
func TestParseIoregOutput_ZeroMemTotal(t *testing.T) {
stat := ParseIoregOutput([]byte(ioregSample), 0)
require.NotNil(t, stat)
assert.Equal(t, 0.0, stat.MemUtilPct)
}
func TestParseIoregOutput_MissingModel(t *testing.T) {
const out = `"Device Utilization %"=50,"In use system memory"=1048576`
stat := ParseIoregOutput([]byte(out), 1024)
require.NotNil(t, stat)
assert.Equal(t, "Apple GPU", stat.Name)
assert.Equal(t, 50.0, stat.GpuUtilPct)
assert.Equal(t, 1, stat.MemUsedMB)
}
+584
View File
@@ -0,0 +1,584 @@
//go:build unix && !darwin
package perf
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/load"
"github.com/shirou/gopsutil/v4/mem"
psnet "github.com/shirou/gopsutil/v4/net"
)
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if ch, err := tryLACT(ctx, every, logger); err == nil {
logger.Info("using LACT for GPU monitoring")
return ch, nil
} else {
logger.Debugf("LACT: %s", err.Error())
}
if ch, err := tryNvidiaSmi(ctx, every, logger); err == nil {
logger.Info("using nvidia-smi for GPU monitoring")
return ch, nil
} else {
logger.Debugf("nvidia-smi: %s", err.Error())
}
if ch, err := tryRocmSmi(ctx, every, logger); err == nil {
logger.Info("using rocm-smi for GPU monitoring")
return ch, nil
} else {
logger.Debugf("rocm-smi: %s", err.Error())
}
if ch, err := trySysfs(ctx, every, logger); err == nil {
logger.Info("using sysfs for GPU monitoring")
return ch, nil
} else {
logger.Debugf("sysfs: %s", err.Error())
}
return nil, ErrNoGpuTool
}
func tryLACT(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
socketPath := lactSocketPath()
if socketPath == "" {
return nil, ErrNoGpuTool
}
conn, err := net.DialTimeout("unix", socketPath, 2*time.Second)
if err != nil {
return nil, fmt.Errorf("cannot connect to LACT socket: %w", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
devices, err := lactListDevices(conn)
if err != nil {
return nil, fmt.Errorf("LACT ListDevices failed: %w", err)
}
if len(devices) == 0 {
return nil, fmt.Errorf("LACT returned no devices")
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
socketPath := lactSocketPath()
if socketPath == "" {
continue
}
conn, err := net.DialTimeout("unix", socketPath, 2*time.Second)
if err != nil {
continue
}
conn.SetDeadline(time.Now().Add(5 * time.Second))
devices, err := lactListDevices(conn)
if err != nil {
conn.Close()
continue
}
stats := make([]GpuStat, 0, len(devices))
for i, d := range devices {
stat, err := lactGetDeviceStats(conn, d.ID, d.Name, i)
if err != nil {
continue
}
if stat.MemTotalMB == 0 {
continue
}
stats = append(stats, stat)
}
conn.Close()
if len(stats) > 0 {
select {
case ch <- stats:
default:
}
}
}
}
}()
return ch, nil
}
func tryNvidiaSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("nvidia-smi"); err != nil {
return nil, ErrNoGpuTool
}
sec := int(every.Seconds())
if sec < 1 {
sec = 1
}
cmd := exec.CommandContext(ctx, "nvidia-smi",
"--query-gpu=index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw",
"--format=csv,noheader,nounits",
"--loop", fmt.Sprintf("%d", sec),
)
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("nvidia-smi stdout pipe failed: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("nvidia-smi start failed: %w", err)
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
stat := ParseNvidiaSmiLine(line)
if stat != nil {
select {
case ch <- []GpuStat{*stat}:
default:
}
}
}
cmd.Wait()
}()
return ch, nil
}
func tryRocmSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("rocm-smi"); err != nil {
return nil, ErrNoGpuTool
}
if every < time.Second {
every = time.Second
}
const pollTimeout = 5 * time.Second
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
pollCtx, cancel := context.WithTimeout(ctx, pollTimeout)
cmd := exec.CommandContext(pollCtx, "rocm-smi", "-i", "-P", "-t", "-f", "-u", "--showmemuse", "--showmeminfo", "vram", "--showproductname", "--csv")
out, err := cmd.Output()
timedOut := pollCtx.Err() == context.DeadlineExceeded
cancel()
if err != nil {
if timedOut {
logger.Debug("rocm-smi timed out")
}
continue
}
stats := make([]GpuStat, 0)
scanner := bufio.NewScanner(strings.NewReader(string(out)))
var header string
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
if strings.HasPrefix(line, "device,") {
header = line
continue
}
stat := parseRocmSmiLine(header, line)
if stat != nil {
stats = append(stats, *stat)
}
}
if len(stats) > 0 {
select {
case ch <- stats:
default:
}
}
}
}
}()
return ch, nil
}
func parseRocmSmiLine(header string, line string) *GpuStat {
if header == "" || line == "" {
return nil
}
labels := strings.Split(header, ",")
fields := strings.Split(line, ",")
if len(labels) != len(fields) {
return nil
}
result := &GpuStat{
Timestamp: time.Now(),
ID: -1,
}
var device string
var deviceName string
var cardSeries string
var gfxVersion string
const toMB = 1024 * 1024
for i, col := range labels {
val := strings.TrimSpace(fields[i])
switch col {
case "device":
device = val
id, err := strconv.Atoi(strings.TrimPrefix(val, "card"))
if err != nil {
return nil
}
result.ID = id
case "Device Name":
deviceName = val
case "GUID":
result.UUID = val
case "Temperature (Sensor edge) (C)":
tempC, _ := strconv.ParseFloat(val, 64)
result.TempC = int(tempC)
case "Temperature (Sensor memory) (C)":
vramTempC, _ := strconv.ParseFloat(val, 64)
result.VramTempC = int(vramTempC)
case "Fan speed (%)":
fanSpeed, _ := strconv.ParseFloat(val, 64)
result.FanSpeedPct = fanSpeed
case "Current Socket Graphics Package Power (W)":
fallthrough
case "Average Graphics Package Power (W)":
powerDraw, _ := strconv.ParseFloat(val, 64)
result.PowerDrawW = powerDraw
case "GPU use (%)":
gpuUtil, _ := strconv.ParseFloat(val, 64)
result.GpuUtilPct = gpuUtil
case "GPU Memory Allocated (VRAM%)":
memUtil, _ := strconv.ParseFloat(val, 64)
result.MemUtilPct = memUtil
case "VRAM Total Memory (B)":
memTotal, _ := strconv.ParseUint(val, 10, 64)
result.MemTotalMB = int(memTotal / toMB)
case "VRAM Total Used Memory (B)":
memUsed, _ := strconv.ParseUint(val, 10, 64)
result.MemUsedMB = int(memUsed / toMB)
case "Card Series":
cardSeries = val
case "GFX Version":
gfxVersion = val
}
}
if result.ID == -1 {
return nil
}
name := device
if cardSeries != "" && cardSeries != "N/A" {
name = cardSeries + " " + device + " (" + gfxVersion + ")"
} else if deviceName != "" && deviceName != "N/A" {
name = deviceName + " " + device + " (" + gfxVersion + ")"
}
result.Name = name
return result
}
func trySysfs(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
return nil, ErrNotImplemented
}
func lactSocketPath() string {
if p := os.Getenv("LACT_DAEMON_SOCKET_PATH"); p != "" {
if _, err := os.Stat(p); err == nil {
return p
}
}
rootPath := "/run/lactd.sock"
if _, err := os.Stat(rootPath); err == nil {
return rootPath
}
u, err := user.Current()
if err != nil {
return ""
}
userPath := filepath.Join("/run/user", u.Uid, "lactd.sock")
if _, err := os.Stat(userPath); err == nil {
return userPath
}
return ""
}
type lactRequest struct {
Command string `json:"command"`
Args interface{} `json:"args,omitempty"`
}
type lactResponse struct {
Status string `json:"status"`
Data json.RawMessage `json:"data"`
}
type lactDeviceEntry struct {
ID string `json:"id"`
Name string `json:"name"`
}
type lactDeviceStats struct {
Fan struct {
PwmCurrent *uint8 `json:"pwm_current"`
} `json:"fan"`
Vram struct {
Total *uint64 `json:"total"`
Used *uint64 `json:"used"`
} `json:"vram"`
Power struct {
Average *float64 `json:"average"`
Current *float64 `json:"current"`
} `json:"power"`
Temps map[string]lactTempEntry `json:"temps"`
BusyPercent *uint8 `json:"busy_percent"`
}
type lactTempEntry struct {
Current *float64 `json:"current"`
}
func lactSendRequest(conn net.Conn, req lactRequest) (json.RawMessage, error) {
data, err := json.Marshal(req)
if err != nil {
return nil, err
}
data = append(data, '\n')
if _, err := conn.Write(data); err != nil {
return nil, err
}
reader := bufio.NewReader(conn)
line, err := reader.ReadBytes('\n')
if err != nil {
return nil, err
}
var resp lactResponse
if err := json.Unmarshal(line, &resp); err != nil {
return nil, err
}
if resp.Status != "ok" {
return nil, fmt.Errorf("LACT error: %s", string(resp.Data))
}
return resp.Data, nil
}
func lactListDevices(conn net.Conn) ([]lactDeviceEntry, error) {
data, err := lactSendRequest(conn, lactRequest{Command: "list_devices"})
if err != nil {
return nil, err
}
var devices []lactDeviceEntry
if err := json.Unmarshal(data, &devices); err != nil {
return nil, err
}
return devices, nil
}
func lactGetDeviceStats(conn net.Conn, id string, name string, index int) (GpuStat, error) {
data, err := lactSendRequest(conn, lactRequest{
Command: "device_stats",
Args: struct {
ID string `json:"id"`
}{ID: id},
})
if err != nil {
return GpuStat{}, err
}
var stats lactDeviceStats
if err := json.Unmarshal(data, &stats); err != nil {
return GpuStat{}, err
}
var memUsedMB, memTotalMB int
if stats.Vram.Used != nil {
memUsedMB = int(*stats.Vram.Used / 1024 / 1024)
}
if stats.Vram.Total != nil {
memTotalMB = int(*stats.Vram.Total / 1024 / 1024)
}
var memUtil float64
if memTotalMB > 0 {
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
}
var gpuUtil float64
if stats.BusyPercent != nil {
gpuUtil = float64(*stats.BusyPercent)
}
var fanSpeed float64
if stats.Fan.PwmCurrent != nil {
fanSpeed = float64(*stats.Fan.PwmCurrent) / 255.0 * 100.0
}
var powerDraw float64
if stats.Power.Average != nil && *stats.Power.Average > 0 {
powerDraw = *stats.Power.Average
} else if stats.Power.Current != nil {
powerDraw = *stats.Power.Current
}
var tempC int
if t, ok := stats.Temps["edge"]; ok && t.Current != nil {
tempC = int(*t.Current)
} else if t, ok := stats.Temps["junction"]; ok && t.Current != nil {
tempC = int(*t.Current)
} else {
for _, t := range stats.Temps {
if t.Current != nil {
tempC = int(*t.Current)
break
}
}
}
var vramTempC int
// nvidia uses "VRAM", amd "mem"
for _, key := range []string{"mem", "VRAM"} {
if t, ok := stats.Temps[key]; ok && t.Current != nil && *t.Current > 0 {
vramTempC = int(*t.Current)
break
}
}
return GpuStat{
Timestamp: time.Now(),
ID: index,
Name: name,
UUID: id,
TempC: tempC,
VramTempC: vramTempC,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtil,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
FanSpeedPct: fanSpeed,
PowerDrawW: powerDraw,
}, nil
}
func readSysfs() ([]GpuStat, error) {
return nil, ErrNotImplemented
}
func readSysStats() (SysStat, error) {
cpuPcts, err := cpu.Percent(0, true)
if err != nil {
return SysStat{}, err
}
vmStat, err := mem.VirtualMemory()
if err != nil {
return SysStat{}, err
}
const toMB = 1024 * 1024
var swapTotalMB, swapUsedMB int
if swapStat, err := mem.SwapMemory(); err == nil {
swapTotalMB = int(swapStat.Total / toMB)
swapUsedMB = int(swapStat.Used / toMB)
}
var loadAvg1, loadAvg5, loadAvg15 float64
if loadStat, err := load.Avg(); err == nil {
loadAvg1 = loadStat.Load1
loadAvg5 = loadStat.Load5
loadAvg15 = loadStat.Load15
}
netIO := make([]NetIOStat, 0)
if ioCounters, err := psnet.IOCounters(true); err == nil {
for _, ioc := range ioCounters {
if ioc.Name == "lo" {
continue
}
netIO = append(netIO, NetIOStat{
Name: ioc.Name,
BytesRecv: ioc.BytesRecv,
BytesSent: ioc.BytesSent,
})
}
}
return SysStat{
Timestamp: time.Now(),
CpuUtilPerCore: cpuPcts,
MemTotalMB: int(vmStat.Total / toMB),
MemUsedMB: int(vmStat.Used / toMB),
MemFreeMB: int(vmStat.Free / toMB),
SwapTotalMB: swapTotalMB,
SwapUsedMB: swapUsedMB,
LoadAvg1: loadAvg1,
LoadAvg5: loadAvg5,
LoadAvg15: loadAvg15,
NetIO: netIO,
}, nil
}
+114
View File
@@ -0,0 +1,114 @@
package perf
import (
"bufio"
"context"
"fmt"
"os/exec"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/mem"
"github.com/shirou/gopsutil/v4/net"
)
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if ch, err := tryNvidiaSmiWindows(ctx, every, logger); err == nil {
logger.Info("using nvidia-smi for GPU monitoring")
return ch, nil
} else {
logger.Debugf("nvidia-smi: %s", err.Error())
}
return nil, ErrNoGpuTool
}
// tryNvidiaSmiWindows starts nvidia-smi in loop mode on Windows and returns
// a channel receiving GPU stat snapshots. Returns ErrNoGpuTool if nvidia-smi
// is not available.
func tryNvidiaSmiWindows(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("nvidia-smi"); err != nil {
return nil, ErrNoGpuTool
}
sec := int(every.Seconds())
if sec < 1 {
sec = 1
}
cmd := exec.CommandContext(ctx, "nvidia-smi",
"--query-gpu=index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw",
"--format=csv,noheader,nounits",
"--loop", fmt.Sprintf("%d", sec),
)
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("nvidia-smi stdout pipe failed: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("nvidia-smi start failed: %w", err)
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
stat := ParseNvidiaSmiLine(line)
if stat != nil {
select {
case ch <- []GpuStat{*stat}:
default:
}
}
}
cmd.Wait()
}()
return ch, nil
}
func readSysStats() (SysStat, error) {
cpuPcts, err := cpu.Percent(0, true)
if err != nil {
return SysStat{}, err
}
vmStat, err := mem.VirtualMemory()
if err != nil {
return SysStat{}, err
}
const toMB = 1024 * 1024
netIO := make([]NetIOStat, 0)
if ioCounters, err := net.IOCounters(true); err == nil {
for _, ioc := range ioCounters {
netIO = append(netIO, NetIOStat{
Name: ioc.Name,
BytesRecv: ioc.BytesRecv,
BytesSent: ioc.BytesSent,
})
}
}
return SysStat{
Timestamp: time.Now(),
CpuUtilPerCore: cpuPcts,
MemTotalMB: int(vmStat.Total / toMB),
MemUsedMB: int(vmStat.Used / toMB),
MemFreeMB: int(vmStat.Free / toMB),
NetIO: netIO,
}, nil
}
+129
View File
@@ -0,0 +1,129 @@
package perf
import (
"fmt"
"net/http"
"sort"
"strings"
)
const mbToBytes = int64(1024 * 1024)
// MetricsHandler returns an http.HandlerFunc serving Prometheus text format metrics
// with the most recent system and GPU stats.
func (m *Monitor) MetricsHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sysStats, gpuStats := m.Current()
w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
if len(sysStats) > 0 {
writeSysMetrics(w, sysStats[len(sysStats)-1])
}
if len(gpuStats) > 0 {
writeGpuMetrics(w, latestPerGPU(gpuStats))
}
}
}
func writeSysMetrics(w http.ResponseWriter, s SysStat) {
fmt.Fprintf(w, "# HELP llamaswap_cpu_util_percent CPU utilization per core (0-100)\n")
fmt.Fprintf(w, "# TYPE llamaswap_cpu_util_percent gauge\n")
for i, pct := range s.CpuUtilPerCore {
fmt.Fprintf(w, "llamaswap_cpu_util_percent{core=\"%d\"} %g\n", i, pct)
}
fmt.Fprintf(w, "# HELP llamaswap_memory_total_bytes Total memory in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_memory_total_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_memory_total_bytes %d\n", int64(s.MemTotalMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_memory_used_bytes Used memory in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_memory_used_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_memory_used_bytes %d\n", int64(s.MemUsedMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_memory_free_bytes Free memory in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_memory_free_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_memory_free_bytes %d\n", int64(s.MemFreeMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_swap_total_bytes Total swap in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_swap_total_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_swap_total_bytes %d\n", int64(s.SwapTotalMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_swap_used_bytes Used swap in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_swap_used_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_swap_used_bytes %d\n", int64(s.SwapUsedMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_load_average Load average\n")
fmt.Fprintf(w, "# TYPE llamaswap_load_average gauge\n")
fmt.Fprintf(w, "llamaswap_load_average{interval=\"1m\"} %g\n", s.LoadAvg1)
fmt.Fprintf(w, "llamaswap_load_average{interval=\"5m\"} %g\n", s.LoadAvg5)
fmt.Fprintf(w, "llamaswap_load_average{interval=\"15m\"} %g\n", s.LoadAvg15)
if len(s.NetIO) > 0 {
fmt.Fprintf(w, "# HELP llamaswap_network_bytes_total Total network bytes transferred\n")
fmt.Fprintf(w, "# TYPE llamaswap_network_bytes_total counter\n")
for _, io := range s.NetIO {
iface := sanitizeLabel(io.Name)
fmt.Fprintf(w, "llamaswap_network_bytes_total{interface=\"%s\",direction=\"recv\"} %d\n", iface, io.BytesRecv)
fmt.Fprintf(w, "llamaswap_network_bytes_total{interface=\"%s\",direction=\"sent\"} %d\n", iface, io.BytesSent)
}
}
}
func writeGpuMetrics(w http.ResponseWriter, gpus []GpuStat) {
if len(gpus) == 0 {
return
}
type gpuMetric struct {
help string
name string
value func(GpuStat) float64
}
metrics := []gpuMetric{
{"GPU temperature in Celsius", "llamaswap_gpu_temperature_celsius", func(g GpuStat) float64 { return float64(g.TempC) }},
{"GPU VRAM temperature in Celsius", "llamaswap_gpu_vram_temperature_celsius", func(g GpuStat) float64 { return float64(g.VramTempC) }},
{"GPU utilization percent (0-100)", "llamaswap_gpu_util_percent", func(g GpuStat) float64 { return g.GpuUtilPct }},
{"GPU memory utilization percent (0-100)", "llamaswap_gpu_memory_util_percent", func(g GpuStat) float64 { return g.MemUtilPct }},
{"GPU memory used in bytes", "llamaswap_gpu_memory_used_bytes", func(g GpuStat) float64 { return float64(g.MemUsedMB) * float64(mbToBytes) }},
{"GPU memory total in bytes", "llamaswap_gpu_memory_total_bytes", func(g GpuStat) float64 { return float64(g.MemTotalMB) * float64(mbToBytes) }},
{"GPU fan speed percent (0-100)", "llamaswap_gpu_fan_speed_percent", func(g GpuStat) float64 { return g.FanSpeedPct }},
{"GPU power draw in watts", "llamaswap_gpu_power_draw_watts", func(g GpuStat) float64 { return g.PowerDrawW }},
}
for _, m := range metrics {
fmt.Fprintf(w, "# HELP %s %s\n", m.name, m.help)
fmt.Fprintf(w, "# TYPE %s gauge\n", m.name)
for _, g := range gpus {
if g.UUID != "" {
fmt.Fprintf(w, "%s{id=\"%d\",name=\"%s\",uuid=\"%s\"} %g\n",
m.name, g.ID, sanitizeLabel(g.Name), sanitizeLabel(g.UUID), m.value(g))
} else {
fmt.Fprintf(w, "%s{id=\"%d\",name=\"%s\"} %g\n",
m.name, g.ID, sanitizeLabel(g.Name), m.value(g))
}
}
}
}
// latestPerGPU returns the most recent GpuStat for each GPU ID, sorted by ID.
func latestPerGPU(stats []GpuStat) []GpuStat {
latest := make(map[int]GpuStat)
for _, g := range stats {
if prev, ok := latest[g.ID]; !ok || g.Timestamp.After(prev.Timestamp) {
latest[g.ID] = g
}
}
result := make([]GpuStat, 0, len(latest))
for _, g := range latest {
result = append(result, g)
}
sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
return result
}
// sanitizeLabel escapes characters that are invalid in Prometheus label values.
func sanitizeLabel(s string) string {
return strings.NewReplacer(`"`, `\"`, `\`, `\\`, "\n", `\n`).Replace(s)
}
+248
View File
@@ -0,0 +1,248 @@
package perf
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSanitizeLabel(t *testing.T) {
tests := []struct {
input string
want string
}{
{"normal", "normal"},
{"", ""},
{`with"quote`, `with\"quote`},
{`with\backslash`, `with\\backslash`},
{"with\nnewline", `with\nnewline`},
{`"both\n"`, `\"both\\n\"`},
}
for _, tc := range tests {
assert.Equal(t, tc.want, sanitizeLabel(tc.input), "input: %q", tc.input)
}
}
func TestLatestPerGPU_Empty(t *testing.T) {
result := latestPerGPU(nil)
assert.Empty(t, result)
}
func TestLatestPerGPU_Single(t *testing.T) {
now := time.Now()
stats := []GpuStat{{ID: 0, Name: "gpu0", Timestamp: now}}
result := latestPerGPU(stats)
require.Len(t, result, 1)
assert.Equal(t, "gpu0", result[0].Name)
}
func TestLatestPerGPU_PicksLatest(t *testing.T) {
earlier := time.Now().Add(-time.Second)
later := time.Now()
stats := []GpuStat{
{ID: 0, Name: "old", TempC: 50, Timestamp: earlier},
{ID: 0, Name: "new", TempC: 70, Timestamp: later},
}
result := latestPerGPU(stats)
require.Len(t, result, 1)
assert.Equal(t, "new", result[0].Name)
assert.Equal(t, 70, result[0].TempC)
}
func TestLatestPerGPU_MultipleGPUsSortedByID(t *testing.T) {
now := time.Now()
stats := []GpuStat{
{ID: 2, Name: "gpu2", Timestamp: now},
{ID: 0, Name: "gpu0", Timestamp: now},
{ID: 1, Name: "gpu1", Timestamp: now},
}
result := latestPerGPU(stats)
require.Len(t, result, 3)
assert.Equal(t, 0, result[0].ID)
assert.Equal(t, 1, result[1].ID)
assert.Equal(t, 2, result[2].ID)
}
func TestWriteSysMetrics(t *testing.T) {
rec := httptest.NewRecorder()
s := SysStat{
CpuUtilPerCore: []float64{10.5, 20.0},
MemTotalMB: 8192,
MemUsedMB: 4096,
MemFreeMB: 4096,
SwapTotalMB: 2048,
SwapUsedMB: 512,
LoadAvg1: 1.5,
LoadAvg5: 1.2,
LoadAvg15: 0.9,
NetIO: []NetIOStat{
{Name: "eth0", BytesRecv: 1000, BytesSent: 2000},
},
}
writeSysMetrics(rec, s)
body := rec.Body.String()
assert.Contains(t, body, `llamaswap_cpu_util_percent{core="0"} 10.5`)
assert.Contains(t, body, `llamaswap_cpu_util_percent{core="1"} 20`)
assert.Contains(t, body, "llamaswap_memory_total_bytes 8589934592")
assert.Contains(t, body, "llamaswap_memory_used_bytes 4294967296")
assert.Contains(t, body, "llamaswap_memory_free_bytes 4294967296")
assert.Contains(t, body, "llamaswap_swap_total_bytes 2147483648")
assert.Contains(t, body, "llamaswap_swap_used_bytes 536870912")
assert.Contains(t, body, `llamaswap_load_average{interval="1m"} 1.5`)
assert.Contains(t, body, `llamaswap_load_average{interval="5m"} 1.2`)
assert.Contains(t, body, `llamaswap_load_average{interval="15m"} 0.9`)
assert.Contains(t, body, `llamaswap_network_bytes_total{interface="eth0",direction="recv"} 1000`)
assert.Contains(t, body, `llamaswap_network_bytes_total{interface="eth0",direction="sent"} 2000`)
}
func TestWriteSysMetrics_NoNetIO(t *testing.T) {
rec := httptest.NewRecorder()
writeSysMetrics(rec, SysStat{CpuUtilPerCore: []float64{5.0}})
body := rec.Body.String()
assert.NotContains(t, body, "llamaswap_network_bytes_total")
}
func TestWriteGpuMetrics_Empty(t *testing.T) {
rec := httptest.NewRecorder()
writeGpuMetrics(rec, nil)
assert.Empty(t, rec.Body.String())
}
func TestWriteGpuMetrics(t *testing.T) {
rec := httptest.NewRecorder()
gpus := []GpuStat{
{
ID: 0,
Name: "NVIDIA RTX 4090",
UUID: "GPU-1234",
TempC: 75,
GpuUtilPct: 85.5,
MemUtilPct: 60.0,
MemUsedMB: 8192,
MemTotalMB: 24576,
FanSpeedPct: 55.0,
PowerDrawW: 300.5,
},
}
writeGpuMetrics(rec, gpus)
body := rec.Body.String()
assert.Contains(t, body, `llamaswap_gpu_temperature_celsius{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 75`)
assert.Contains(t, body, `llamaswap_gpu_vram_temperature_celsius{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 0`)
assert.Contains(t, body, `llamaswap_gpu_util_percent{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 85.5`)
assert.Contains(t, body, `llamaswap_gpu_memory_util_percent{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 60`)
assert.Contains(t, body, `llamaswap_gpu_memory_used_bytes{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"}`)
assert.Contains(t, body, `llamaswap_gpu_memory_total_bytes{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"}`)
assert.Contains(t, body, `llamaswap_gpu_fan_speed_percent{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 55`)
assert.Contains(t, body, `llamaswap_gpu_power_draw_watts{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 300.5`)
}
func TestWriteGpuMetrics_VramTemp(t *testing.T) {
rec := httptest.NewRecorder()
gpus := []GpuStat{
{ID: 0, Name: "AMD RX 7900", UUID: "GPU-5678", TempC: 70, VramTempC: 85},
}
writeGpuMetrics(rec, gpus)
body := rec.Body.String()
assert.Contains(t, body, `llamaswap_gpu_temperature_celsius{id="0",name="AMD RX 7900",uuid="GPU-5678"} 70`)
assert.Contains(t, body, `llamaswap_gpu_vram_temperature_celsius{id="0",name="AMD RX 7900",uuid="GPU-5678"} 85`)
}
func TestWriteGpuMetrics_EmptyUUID(t *testing.T) {
rec := httptest.NewRecorder()
gpus := []GpuStat{{ID: 3, Name: "AMD RX 7900", UUID: ""}}
writeGpuMetrics(rec, gpus)
body := rec.Body.String()
assert.NotContains(t, body, "uuid=")
assert.Contains(t, body, `name="AMD RX 7900"`)
}
func TestWriteGpuMetrics_LabelSanitization(t *testing.T) {
rec := httptest.NewRecorder()
gpus := []GpuStat{
{ID: 0, Name: `GPU "special"`, UUID: "uuid\nline"},
}
writeGpuMetrics(rec, gpus)
body := rec.Body.String()
assert.Contains(t, body, `name="GPU \"special\""`)
assert.Contains(t, body, `uuid="uuid\nline"`)
}
func TestMetricsHandler_ContentType(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
assert.Equal(t, "text/plain; version=0.0.4; charset=utf-8", rec.Header().Get("Content-Type"))
}
func TestMetricsHandler_EmptyStats(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Empty(t, strings.TrimSpace(rec.Body.String()))
}
func TestMetricsHandler_WithSysStats(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
m.sysRing.Push(SysStat{Timestamp: time.Now(), CpuUtilPerCore: []float64{25.0}, MemTotalMB: 4096, MemUsedMB: 2048, MemFreeMB: 2048})
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
body := rec.Body.String()
assert.Contains(t, body, "llamaswap_cpu_util_percent")
assert.Contains(t, body, "llamaswap_memory_total_bytes")
}
func TestMetricsHandler_UsesLatestSysStat(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
now := time.Now()
m.sysRing.Push(SysStat{Timestamp: now.Add(-time.Second), MemTotalMB: 1000})
m.sysRing.Push(SysStat{Timestamp: now, MemTotalMB: 8192})
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
body := rec.Body.String()
// 8192 MB = 8589934592 bytes
assert.Contains(t, body, "llamaswap_memory_total_bytes 8589934592")
}
func TestMetricsHandler_WithGpuStats(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
m.gpuRing.Push([]GpuStat{{ID: 0, Name: "TestGPU", UUID: "uuid-0", TempC: 65, Timestamp: time.Now()}})
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
body := rec.Body.String()
assert.Contains(t, body, "llamaswap_gpu_temperature_celsius")
assert.Contains(t, body, `name="TestGPU"`)
}
+40
View File
@@ -0,0 +1,40 @@
package perf
import "time"
type GpuStat struct {
Timestamp time.Time `json:"timestamp"`
ID int `json:"id"`
Name string `json:"name"`
UUID string `json:"uuid"`
TempC int `json:"temp_c"`
VramTempC int `json:"vram_temp_c"`
GpuUtilPct float64 `json:"gpu_util_pct"`
MemUtilPct float64 `json:"mem_util_pct"`
MemUsedMB int `json:"mem_used_mb"`
MemTotalMB int `json:"mem_total_mb"`
FanSpeedPct float64 `json:"fan_speed_pct"`
PowerDrawW float64 `json:"power_draw_w"`
}
type NetIOStat struct {
Name string `json:"name"`
BytesRecv uint64 `json:"bytes_recv"`
BytesSent uint64 `json:"bytes_sent"`
}
type SysStat struct {
Timestamp time.Time `json:"timestamp"`
CpuUtilPerCore []float64 `json:"cpu_util_per_core"`
MemTotalMB int `json:"mem_total_mb"`
MemUsedMB int `json:"mem_used_mb"`
MemFreeMB int `json:"mem_free_mb"`
SwapTotalMB int `json:"swap_total_mb"`
SwapUsedMB int `json:"swap_used_mb"`
LoadAvg1 float64 `json:"load_avg_1"`
LoadAvg5 float64 `json:"load_avg_5"`
LoadAvg15 float64 `json:"load_avg_15"`
NetIO []NetIOStat `json:"net_io"`
}
+49
View File
@@ -0,0 +1,49 @@
package process
import (
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
var simpleResponderPath string
func skipIfNoSimpleResponder(t *testing.T) {
t.Helper()
if _, err := os.Stat(simpleResponderPath); os.IsNotExist(err) {
t.Skipf("simple-responder not found at %s, run `make simple-responder`", simpleResponderPath)
}
}
func TestMain(m *testing.M) {
goos := runtime.GOOS
goarch := runtime.GOARCH
if goos == "windows" {
simpleResponderPath = filepath.Join("..", "..", "build", "simple-responder.exe")
} else {
simpleResponderPath = filepath.Join("..", "..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
m.Run()
}
func getFreePort(t *testing.T) int {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("getFreePort: %v", err)
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port
}
func simpleResponderCmd(t *testing.T, args ...string) (string, int) {
port := getFreePort(t)
cmdPath := filepath.ToSlash(simpleResponderPath)
base := []string{cmdPath, fmt.Sprintf("-port %d", port)}
base = append(base, args...)
return strings.Join(base, " "), port
}
+49
View File
@@ -0,0 +1,49 @@
package process
import (
"context"
"net/http"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateStarting ProcessState = ProcessState("starting")
StateReady ProcessState = ProcessState("ready")
StateStopping ProcessState = ProcessState("stopping")
// process is shutdown and will not be restarted
StateShutdown ProcessState = ProcessState("shutdown")
)
type Process interface {
// Run starts the process blocks until the process is terminated.
// The timeout parameter controls how long to wait for the process to get
// to a ready state to process traffic
Run(timeout time.Duration) error
// WaitReady blocks until the process is ready to serve requests
// or the context is cancelled. It returns nil when the process is ready
WaitReady(context.Context) error
// Stop blocks until the process has terminated. It returns nil when
// the process terminated as expected (exit 0)
Stop(timeout time.Duration) error
// State returns the current state of the process
// Note: this is a snapshot of the state at the time of the call
// and may change at any time after the call returns.
State() ProcessState
// ServeHTTP forwards requests to the underlying process
// Calling it when the process is not ready will result in a
// 503 response with a body indicating it is a llama-swap-error
ServeHTTP(http.ResponseWriter, *http.Request)
// Logger returns the monitor that captures this process's stdout/stderr.
Logger() *logmon.Monitor
}
+679
View File
@@ -0,0 +1,679 @@
package process
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"os/exec"
"strings"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/shared"
)
var ErrStartAborted = fmt.Errorf("aborted")
// cmdWaitDelay is the upper bound the runtime will wait for child I/O to
// drain after the process exits before force-closing the stdout/stderr
// pipes. Required so that cmd.Wait() returns even when a forked grandchild
// inherits and holds the pipes open (e.g. a shell wrapper that backgrounds
// the real binary). killProcess sends the stop signal directly (not via the
// cmd context), so this delay is measured from process exit rather than from
// the stop request, and stays independent of the caller's graceful timeout.
const cmdWaitDelay = 10 * time.Second
// parentCancelGraceTimeout is the graceful timeout used when the process is
// torn down because parentCtx was cancelled (final router teardown or app
// shutdown). In the normal flow the process has already been stopped via
// Stop() by this point, so killProcess is a no-op kill; the short grace just
// bounds the rare case where a process is still alive when its context is cut.
const parentCancelGraceTimeout = time.Second
type runReq struct {
timeout time.Duration
respond chan error
}
type stopReq struct {
timeout time.Duration
respond chan error
}
type waitReadyReq struct {
respond chan error
}
type startResult struct {
cmd *exec.Cmd
cmdDone chan struct{}
cancel context.CancelFunc
handlerFn http.HandlerFunc
err error
}
type ProcessCommand struct {
id string
config config.ModelConfig
parentCtx context.Context
processLogger *logmon.Monitor
proxyLogger *logmon.Monitor
// waitDelay is assigned to cmd.WaitDelay when starting the upstream
// process. Defaults to cmdWaitDelay; tests override it to keep the
// pipe-close backstop from dominating their runtime.
waitDelay time.Duration
runCh chan runReq
stopCh chan stopReq
waitReadyCh chan waitReadyReq
// current ProcessState. Written only by run(); read by State() via atomic load.
state atomic.Value
// stores the active reverse-proxy handler when the process is running.
// Written only by run(); read by ServeHTTP via atomic load.
handler atomic.Pointer[http.HandlerFunc]
lastUse atomic.Int64 // unix nano timestamp of last ServeHTTP completion
inflight atomic.Int64 // current in-flight ServeHTTP calls
}
var _ Process = (*ProcessCommand)(nil)
func New(
parentCtx context.Context,
id string,
conf config.ModelConfig,
processLogger *logmon.Monitor,
proxyLogger *logmon.Monitor,
) (*ProcessCommand, error) {
p := &ProcessCommand{
id: id,
config: conf,
parentCtx: parentCtx,
processLogger: processLogger,
proxyLogger: proxyLogger,
runCh: make(chan runReq),
stopCh: make(chan stopReq),
waitReadyCh: make(chan waitReadyReq),
waitDelay: cmdWaitDelay,
}
p.state.Store(StateStopped)
go p.run()
return p, nil
}
func (p *ProcessCommand) Logger() *logmon.Monitor { return p.processLogger }
// run is the single-writer goroutine that owns all mutable lifecycle state
// (current ProcessState, the running *exec.Cmd, the active reverse-proxy
// handler, and the list of WaitReady subscribers). Every public method
// (Run / Stop / State / WaitReady) is a thin client that sends a request on
// one of the channels below and waits for a response — this funnels concurrent
// callers through a single serialization point so the state machine never
// observes a race.
func (p *ProcessCommand) run() {
// Mutable state — only read/written from this goroutine. ServeHTTP reads
// p.handler concurrently, which is why handler is an atomic.Pointer.
// p.state mirrors `state` so State() can observe transitions; setState
// writes both.
state := StateStopped
setState := func(s ProcessState) {
old := state
state = s
p.state.Store(s)
if old != s {
event.Emit(shared.ProcessStateChangeEvent{
ProcessName: p.id,
OldState: string(old),
NewState: string(s),
})
}
}
var (
cmd *exec.Cmd
cmdDone <-chan struct{}
cmdCancel context.CancelFunc
readyWaiters []waitReadyReq
// runResp parks the in-flight Run caller's response channel. The
// interface contract is that Run blocks until the process is
// terminated, so we hold this until Stop, parentCtx, or an
// upstream exit unblocks it via respondRun.
runResp chan<- error
)
// notifyWaiters wakes every blocked WaitReady caller with the given result.
// Used on transitions out of StateStarting (ready, failed, aborted, or
// shutdown) — anything that resolves the "is it ready yet?" question.
notifyWaiters := func(err error) {
for _, w := range readyWaiters {
select {
case w.respond <- err:
default:
}
}
readyWaiters = nil
}
// respondRun delivers the final Run result, if a Run caller is parked.
respondRun := func(err error) {
if runResp != nil {
runResp <- err
runResp = nil
}
}
for {
select {
// Shutdown: parent context cancelled. Tear down any running process,
// wake any pending WaitReady callers with an error, then exit the
// goroutine permanently. Subsequent public-method calls will fail
// because parentCtx.Done() unblocks their send-side selects.
case <-p.parentCtx.Done():
// Mark shutdown before killProcess so concurrent State() readers
// stop treating this process as ready while the (possibly slow)
// teardown is in progress.
setState(StateShutdown)
if cmd != nil {
p.handler.Store(nil)
p.killProcess(cmd, cmdCancel, cmdDone, parentCancelGraceTimeout)
cmd = nil
cmdDone = nil
cmdCancel = nil
}
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
respondRun(fmt.Errorf("[%s] shutdown", p.id))
return
// Upstream exited on its own (not via Stop). Drop handler state,
// transition to Stopped, and unblock the parked Run caller.
// cmdDone is nil while no process is running, so this case is
// dormant outside of StateReady.
case <-cmdDone:
if cmdCancel != nil {
cmdCancel()
}
cmd = nil
cmdDone = nil
cmdCancel = nil
p.handler.Store(nil)
setState(StateStopped)
respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id))
// WaitReady: if we're already in a terminal-for-this-question state,
// respond immediately; otherwise queue the caller and let a future
// state transition wake them via notifyWaiters.
case req := <-p.waitReadyCh:
switch state {
case StateReady:
req.respond <- nil
case StateShutdown:
req.respond <- fmt.Errorf("[%s] shutdown", p.id)
default:
readyWaiters = append(readyWaiters, req)
}
// Run: start the upstream process. Only valid from StateStopped.
// doStart can take a long time (health-check polling), so it runs in
// a separate goroutine and we wait on resultCh. While waiting we also
// listen for an incoming Stop — that's how callers cancel an in-flight
// start.
case req := <-p.runCh:
if state != StateStopped {
req.respond <- fmt.Errorf("[%s] could not be started in %s state", p.id, state)
continue
}
setState(StateStarting)
startCtx, cancelStart := context.WithCancel(context.Background())
resultCh := make(chan startResult, 1)
go func() {
resultCh <- p.doStart(startCtx, req.timeout)
}()
// pendingStop holds a Stop request that arrived mid-start, so we
// can respond to it AFTER we've finished tearing the start down.
var pendingStop *stopReq
select {
// doStart finished on its own — either successfully (latch
// cmd/handler and move to Ready) or with an error (back to
// Stopped). Either way wake WaitReady subscribers and reply
// to the Run caller.
case res := <-resultCh:
if res.err == nil {
cmd = res.cmd
cmdDone = res.cmdDone
cmdCancel = res.cancel
fn := res.handlerFn
p.handler.Store(&fn)
setState(StateReady)
notifyWaiters(nil)
// Park the Run response — Run blocks until the process
// terminates, so we only fire this when Stop, parentCtx,
// or the upstream exit takes the process down.
runResp = req.respond
// Start TTL goroutine if configured — self-terminates
// when state leaves StateReady.
if p.config.UnloadAfter > 0 {
ttlDuration := time.Duration(p.config.UnloadAfter) * time.Second
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
if p.State() != StateReady {
return
}
if p.inflight.Load() != 0 {
continue
}
if time.Since(time.Unix(0, p.lastUse.Load())) > ttlDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.id, p.config.UnloadAfter)
p.Stop(10 * time.Second)
return
}
}
}()
}
} else {
setState(StateStopped)
notifyWaiters(res.err)
req.respond <- res.err
}
// Stop arrived while doStart was still running. Cancel the
// start context to abort it, then wait for doStart to return.
// If doStart had already crossed the finish line before
// cancellation took effect, it returns a live cmd that we
// must kill ourselves. The Run caller gets ErrAbort; the Stop
// caller is parked in pendingStop and answered below.
case stop := <-p.stopCh:
cancelStart()
res := <-resultCh
if res.cmd != nil {
p.killProcess(res.cmd, res.cancel, res.cmdDone, stop.timeout)
}
setState(StateStopped)
notifyWaiters(ErrStartAborted)
req.respond <- ErrStartAborted
pendingStop = &stop
// Parent context cancelled (e.g. config reload) while doStart
// was still running. Stop() returns early when parentCtx is
// done and never sends on stopCh, so we must handle shutdown
// here to avoid leaving doStart running indefinitely.
case <-p.parentCtx.Done():
cancelStart()
// Mark shutdown before tearing the process down: killProcess
// may block (e.g. taskkill on Windows is slow to spawn), and
// callers observing State() should see StateShutdown promptly
// rather than a stale StateStarting.
setState(StateShutdown)
res := <-resultCh
if res.cmd != nil {
p.killProcess(res.cmd, res.cancel, res.cmdDone, parentCancelGraceTimeout)
}
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
respondRun(fmt.Errorf("[%s] shutdown", p.id))
return
}
// cancelStart is idempotent; calling it again here ensures the
// context is released even on the success path (govet leak check).
cancelStart()
if pendingStop != nil {
pendingStop.respond <- nil
}
// Stop: tear down a running process.
case stop := <-p.stopCh:
if cmd != nil {
setState(StateStopping)
p.killProcess(cmd, cmdCancel, cmdDone, stop.timeout)
cmd = nil
cmdDone = nil
cmdCancel = nil
p.handler.Store(nil)
}
// Stop is a no-op (and not an error) when already Stopped — this
// is what makes it idempotent for callers that don't track state.
setState(StateStopped)
respondRun(nil)
stop.respond <- nil
}
}
}
func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout time.Duration) startResult {
if p.config.Proxy == "" {
return startResult{err: fmt.Errorf("upstream proxy missing")}
}
args, err := p.config.SanitizedCommand()
if err != nil {
return startResult{err: fmt.Errorf("unable to get sanitized command: %w", err)}
}
proxyURL, err := url.Parse(p.config.Proxy)
if err != nil {
return startResult{err: fmt.Errorf("invalid proxy URL %q: %w", p.config.Proxy, err)}
}
reverseProxy := httputil.NewSingleHostReverseProxy(proxyURL)
reverseProxy.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(p.config.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(p.config.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(p.config.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(p.config.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(p.config.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(p.config.Timeouts.IdleConn) * time.Second,
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
// httputil.ReverseProxy panics with http.ErrAbortHandler when the upstream
// disconnects after response headers have been sent. Recover here so the
// streaming termination is treated as a normal client/upstream disconnect.
// see: https://github.com/golang/go/issues/23643
handlerFn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
if rec == http.ErrAbortHandler {
p.proxyLogger.Infof("<%s> recovered from upstream disconnection during streaming", p.id)
} else {
p.proxyLogger.Warnf("<%s> recovered from panic: %v", p.id, rec)
}
}
}()
reverseProxy.ServeHTTP(w, r)
})
// cmdCtx + cmd.Cancel are wired as a safety net: if the context is ever
// cancelled while the process is alive, cmd.Cancel sends SIGTERM / CmdStop
// and the runtime escalates to SIGKILL after cmd.WaitDelay. In the normal
// teardown path killProcess sends the stop signal directly instead, so
// cmd.WaitDelay only acts as the inherited-pipe backstop measured from
// process exit (see killProcess).
cmdCtx, cmdCancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
cmd.Stderr = p.processLogger
cmd.Stdout = p.processLogger
cmd.Env = append(cmd.Environ(), p.config.Env...)
cmd.Cancel = func() error { return p.sendStopSignal(cmd) }
cmd.WaitDelay = p.waitDelay
setProcAttributes(cmd)
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
cmdDone := make(chan struct{})
if err := cmd.Start(); err != nil {
cmdCancel()
return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)}
}
go func() {
waitErr := cmd.Wait()
switch st := p.State(); {
case waitErr == nil:
p.proxyLogger.Debugf("<%s> process exited cleanly", p.id)
case st == StateStopping || st == StateShutdown:
// Expected: we force-terminated the process. A forced kill exits
// the child with a non-zero code (e.g. taskkill /f on Windows
// yields exit status 1), so this is not an error.
p.proxyLogger.Debugf("<%s> process stopped by llama-swap: %v", p.id, waitErr)
default:
if exitErr, ok := waitErr.(*exec.ExitError); ok {
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
} else {
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
}
}
close(cmdDone)
}()
abort := func(err error) startResult {
p.killProcess(cmd, cmdCancel, cmdDone, 5*time.Second)
return startResult{err: err}
}
prematureExit := func() startResult {
cmdCancel()
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
}
if startCtx.Err() != nil {
return abort(ErrStartAborted)
}
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
if checkEndpoint == "none" {
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
}
// Wait 250ms for the command to start up before health checking
select {
case <-startCtx.Done():
return abort(ErrStartAborted)
case <-time.After(250 * time.Millisecond):
}
deadline := time.Now().Add(healthCheckTimeout)
for {
select {
case <-startCtx.Done():
return abort(ErrStartAborted)
case <-cmdDone:
return prematureExit()
default:
}
if time.Now().After(deadline) {
return abort(fmt.Errorf("health check timed out after %v", healthCheckTimeout))
}
req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil)
rr := httptest.NewRecorder()
reverseProxy.ServeHTTP(rr, req)
resp := rr.Result()
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint)
break
} else if startCtx.Err() != nil {
return abort(ErrStartAborted)
}
select {
case <-startCtx.Done():
return abort(ErrStartAborted)
case <-cmdDone:
return prematureExit()
case <-time.After(time.Second):
}
}
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
}
// sendStopSignal runs the configured CmdStop (if any) or sends SIGTERM to
// the upstream process. Wired up as cmd.Cancel so it fires whenever the
// cmd's context is cancelled.
func (p *ProcessCommand) sendStopSignal(cmd *exec.Cmd) error {
if cmd == nil || cmd.Process == nil {
p.processLogger.Debugf("<%s> sendStopSignal() called with nil cmd or process, nothing to stop", p.id)
return nil
}
pid := cmd.Process.Pid
if p.config.CmdStop != "" {
p.processLogger.Debugf("<%s> sendStopSignal() using CmdStop %q for pid %d", p.id, p.config.CmdStop, pid)
stopArgs, err := config.SanitizeCommand(
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", pid)),
)
if err == nil {
p.processLogger.Debugf("<%s> sendStopSignal() running stop command: %s", p.id, strings.Join(stopArgs, " "))
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Env = cmd.Env
setProcAttributes(stopCmd)
runErr := stopCmd.Run()
if runErr != nil {
p.processLogger.Errorf("<%s> sendStopSignal() stop command failed: %v", p.id, runErr)
} else {
p.processLogger.Debugf("<%s> sendStopSignal() stop command completed for pid %d", p.id, pid)
}
return runErr
}
// fall through to SIGTERM if sanitize failed
p.processLogger.Errorf("<%s> sendStopSignal() failed to sanitize CmdStop %q: %v, falling back to terminateProcessTree", p.id, p.config.CmdStop, err)
}
// On Unix this SIGTERMs the whole process group so a forked grandchild
// (e.g. a shell wrapper that backgrounds the real binary) is taken down
// with the parent rather than orphaned.
p.processLogger.Debugf("<%s> sendStopSignal() no CmdStop configured, calling terminateProcessTree for pid %d", p.id, pid)
termErr := terminateProcessTree(cmd)
if termErr != nil {
p.processLogger.Errorf("<%s> sendStopSignal() terminateProcessTree failed for pid %d: %v", p.id, pid, termErr)
}
return termErr
}
// killProcess terminates the upstream process. The flow:
//
// 1. Send the graceful stop signal (CmdStop / SIGTERM) directly — NOT by
// cancelling cmdCtx. Cancelling the context would start cmd.WaitDelay
// immediately, which force-kills the process WaitDelay after the signal
// and would silently cap gracefulTimeout at WaitDelay whenever
// gracefulTimeout is the longer of the two.
// 2. We wait up to gracefulTimeout for the process to exit on its own.
// 3. If still alive, we SIGKILL the process group directly (Unix) so any
// forked descendant is force-terminated alongside the parent.
// 4. We wait on cmdDone. cmd.WaitDelay (set when the cmd was built) is the
// critical backstop here: once the process exits, if a forked grandchild
// inherited the stdout/stderr pipes and is still holding them, the runtime
// force-closes the pipes WaitDelay after the exit and cmd.Wait() unblocks.
// Because we never cancelled the context, that WaitDelay timer measures
// from process exit (see os/exec awaitGoroutines), not from this call.
// Without WaitDelay this select would hang forever (the v219 bug).
//
// cancel() is still invoked (deferred) to release the context, but only after
// the process has exited and os/exec's ctx watcher has already torn down, so it
// never re-fires cmd.Cancel.
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cancel context.CancelFunc, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
if cancel == nil {
return
}
defer cancel()
// Deliver CmdStop / SIGTERM in a goroutine so a slow or hanging CmdStop
// cannot block the run() goroutine; the gracefulTimeout + Process.Kill
// path below still guarantees teardown.
if cmd != nil {
go func() { _ = p.sendStopSignal(cmd) }()
}
timer := time.NewTimer(gracefulTimeout)
defer timer.Stop()
select {
case <-cmdDone:
return
case <-timer.C:
}
if cmd != nil {
// SIGKILL the whole process group on Unix so any descendant that
// ignored or outlived the graceful signal is force-terminated too.
_ = killProcessTree(cmd)
}
<-cmdDone
}
func (p *ProcessCommand) ID() string {
return p.id
}
func (p *ProcessCommand) Run(timeout time.Duration) error {
req := runReq{
timeout: timeout,
respond: make(chan error, 1),
}
select {
case p.runCh <- req:
case <-p.parentCtx.Done():
return fmt.Errorf("[%s] shutdown", p.id)
}
select {
case err := <-req.respond:
return err
case <-p.parentCtx.Done():
return fmt.Errorf("[%s] shutdown", p.id)
}
}
func (p *ProcessCommand) WaitReady(ctx context.Context) error {
req := waitReadyReq{respond: make(chan error, 1)}
select {
case p.waitReadyCh <- req:
case <-ctx.Done():
return ctx.Err()
case <-p.parentCtx.Done():
return fmt.Errorf("[%s] shutdown", p.id)
}
select {
case err := <-req.respond:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (p *ProcessCommand) Stop(timeout time.Duration) error {
req := stopReq{
timeout: timeout,
respond: make(chan error, 1),
}
select {
case p.stopCh <- req:
case <-p.parentCtx.Done():
return fmt.Errorf("[%s] shutdown", p.id)
}
return <-req.respond
}
func (p *ProcessCommand) State() ProcessState {
if s, ok := p.state.Load().(ProcessState); ok {
return s
}
return StateStopped
}
func (p *ProcessCommand) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fn := p.handler.Load()
if fn == nil {
http.Error(w, fmt.Sprintf("llama-swap-error: [%s] process is not ready", p.id), http.StatusServiceUnavailable)
return
}
p.inflight.Add(1)
defer func() {
p.lastUse.Store(time.Now().UnixNano())
p.inflight.Add(-1)
}()
(*fn)(w, r)
}
@@ -0,0 +1,262 @@
//go:build !windows
package process
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
)
// TestProcessCommand_StopForkingWrapper is a regression for the bug reported
// against v219 where Stop would hang indefinitely when the upstream command
// is a shell wrapper that forks the real binary (e.g. `#!/bin/bash` then
// `"$@"`). After SIGTERM the wrapper dies but the grandchild inherits the
// stdout/stderr pipes; cmd.Wait() blocks waiting for the pipe-copy goroutine
// to drain EOF, which never happens while the grandchild holds the fds.
//
// The fix is cmd.WaitDelay (combined with exec.CommandContext + cmd.Cancel),
// which causes the runtime to force-close the pipes after the delay so
// cmd.Wait() — and therefore Stop — returns.
func TestProcessCommand_StopForkingWrapper(t *testing.T) {
skipIfNoSimpleResponder(t)
port := getFreePort(t)
dir := t.TempDir()
pidFile := filepath.Join(dir, "child.pid")
// Wrapper script: backgrounds the child (which inherits stdout/stderr),
// records its PID for cleanup, then waits. When SIGTERM hits bash it
// dies without forwarding the signal; the grandchild keeps running and
// keeps the inherited pipe fds open. This is the scenario reported in
// the v219 regression.
wrapper := filepath.Join(dir, "wrapper.sh")
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
simpleResponderPath, port, pidFile)
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
t.Fatalf("WriteFile: %v", err)
}
t.Cleanup(func() { killChildFromPidFile(pidFile) })
p := newProcessCommand(t, config.ModelConfig{
Cmd: wrapper,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
// Shrink the pipe-close backstop so the test doesn't sit at the
// production default (10s). Must be set before Run() so doStart picks
// it up when building the cmd.
const testWaitDelay = 250 * time.Millisecond
p.waitDelay = testWaitDelay
runErr := runAsync(t, p)
// Stop must return within a bounded time even though the grandchild
// is still holding the pipe open. Budget is generous on top of
// testWaitDelay to absorb scheduling jitter on slow CI runners; the
// pre-fix behaviour was an unbounded hang, so any reasonable cap
// distinguishes pass from fail.
stopReturned := make(chan error, 1)
stopStart := time.Now()
go func() { stopReturned <- p.Stop(testStopTimeout) }()
const stopBudget = testWaitDelay + 2*time.Second
select {
case err := <-stopReturned:
if err != nil {
t.Fatalf("Stop: %v", err)
}
t.Logf("Stop returned in %v", time.Since(stopStart))
case <-time.After(stopBudget):
t.Fatalf("Stop did not return within %v — cmd.Wait() likely hung on inherited pipe", stopBudget)
}
if got := p.State(); got != StateStopped {
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Errorf("Run did not return after Stop")
}
}
// TestProcessCommand_StopHonorsGracefulTimeout is a regression for the bug
// where cmd.WaitDelay capped the graceful shutdown window. killProcess used to
// cancel the cmd context to deliver SIGTERM, which starts cmd.WaitDelay
// immediately; a process whose SIGTERM handler needs longer than WaitDelay to
// finish was force-killed early even though Stop was given a much longer
// timeout. The fix sends the signal directly so WaitDelay measures from process
// exit (its inherited-pipe backstop role), leaving the graceful window to the
// caller's Stop timeout.
func TestProcessCommand_StopHonorsGracefulTimeout(t *testing.T) {
dir := t.TempDir()
marker := filepath.Join(dir, "graceful.done")
ready := filepath.Join(dir, "trap.ready")
// On SIGTERM, sleep past the (short) WaitDelay, then write the marker and
// exit cleanly. If WaitDelay still drove the kill, bash would be SIGKILLed
// mid-handler and the marker would never be written. The ready file is
// written only after the trap is installed so the test does not race
// SIGTERM ahead of it (CheckEndpoint:none marks ready before bash runs).
script := filepath.Join(dir, "graceful.sh")
body := fmt.Sprintf(
"#!/bin/bash\ncleanup() { sleep 0.6; echo done > %q; exit 0; }\ntrap cleanup SIGTERM\necho ready > %q\nwhile true; do sleep 0.1; done\n",
marker, ready,
)
if err := os.WriteFile(script, []byte(body), 0o755); err != nil {
t.Fatalf("WriteFile: %v", err)
}
p := newProcessCommand(t, config.ModelConfig{
Cmd: script,
Proxy: "http://127.0.0.1:1", // unused: health check disabled
CheckEndpoint: "none",
})
// WaitDelay shorter than the handler's 0.6s sleep, and far shorter than the
// Stop timeout below — this is the window the old code mis-killed in.
p.waitDelay = 200 * time.Millisecond
runErr := runAsync(t, p)
// Wait until the trap is installed before stopping.
trapDeadline := time.Now().Add(2 * time.Second)
for {
if _, err := os.Stat(ready); err == nil {
break
}
if time.Now().After(trapDeadline) {
t.Fatalf("script did not install SIGTERM trap in time")
}
time.Sleep(10 * time.Millisecond)
}
stopStart := time.Now()
if err := p.Stop(5 * time.Second); err != nil {
t.Fatalf("Stop: %v", err)
}
elapsed := time.Since(stopStart)
// The handler must have run to completion (marker written) rather than
// being force-killed at waitDelay.
if _, err := os.Stat(marker); err != nil {
t.Fatalf("graceful handler did not complete (marker missing): %v", err)
}
// And Stop must have waited for the handler (>~0.6s), not returned at the
// 200ms waitDelay.
if elapsed < 500*time.Millisecond {
t.Fatalf("Stop returned in %v — process was killed before its graceful handler finished", elapsed)
}
if got := p.State(); got != StateStopped {
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Errorf("Run did not return after Stop")
}
}
// TestProcessCommand_StopReapsForkedGrandchild verifies that stopping a forking
// wrapper takes down the backgrounded grandchild too, rather than leaving it as
// an orphan. The fix is Setpgid (runtime_unix.go): the wrapper leads its own
// process group, so the stop signal is delivered to the whole group via the
// negative PID and reaches the grandchild the wrapper never reaped.
func TestProcessCommand_StopReapsForkedGrandchild(t *testing.T) {
skipIfNoSimpleResponder(t)
port := getFreePort(t)
dir := t.TempDir()
pidFile := filepath.Join(dir, "child.pid")
wrapper := filepath.Join(dir, "wrapper.sh")
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
simpleResponderPath, port, pidFile)
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
t.Fatalf("WriteFile: %v", err)
}
t.Cleanup(func() { killChildFromPidFile(pidFile) })
p := newProcessCommand(t, config.ModelConfig{
Cmd: wrapper,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
runErr := runAsync(t, p)
// Read the grandchild PID the wrapper recorded.
var childPID int
deadline := time.Now().Add(2 * time.Second)
for {
data, err := os.ReadFile(pidFile)
if err == nil {
if pid, perr := strconv.Atoi(strings.TrimSpace(string(data))); perr == nil && pid > 0 {
childPID = pid
break
}
}
if time.Now().After(deadline) {
t.Fatalf("wrapper did not record grandchild PID")
}
time.Sleep(10 * time.Millisecond)
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop: %v", err)
}
// After Stop the grandchild must be gone. Signal 0 probes liveness without
// actually sending a signal; give it a brief window to exit after the
// group SIGTERM.
proc, err := os.FindProcess(childPID)
if err != nil {
t.Fatalf("FindProcess: %v", err)
}
gone := false
for i := 0; i < 100; i++ {
if err := proc.Signal(syscall.Signal(0)); err != nil {
gone = true
break
}
time.Sleep(10 * time.Millisecond)
}
if !gone {
t.Errorf("grandchild PID %d still alive after Stop — process group was not reaped", childPID)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Errorf("Run did not return after Stop")
}
}
// killChildFromPidFile reads a PID written by the wrapper script and SIGKILLs
// it so leaked orphans don't accumulate between test runs. Best-effort.
func killChildFromPidFile(pidFile string) {
data, err := os.ReadFile(pidFile)
if err != nil {
return
}
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
if err != nil || pid <= 0 {
return
}
proc, err := os.FindProcess(pid)
if err != nil {
return
}
_ = proc.Kill()
}
+646
View File
@@ -0,0 +1,646 @@
package process
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
const (
testStartTimeout = 3 * time.Second
testStopTimeout = 2 * time.Second
testReturnTimeout = 1 * time.Second
testPollInterval = 20 * time.Millisecond
testLogPollInterval = 10 * time.Millisecond
)
func newProcessCommand(t *testing.T, conf config.ModelConfig) *ProcessCommand {
t.Helper()
logger := logmon.NewWriter(io.Discard)
p, err := New(context.Background(), t.Name(), conf, logger, logger)
if err != nil {
t.Fatalf("New: %v", err)
}
return p
}
// runAsync starts Run in a goroutine and waits until the process is ready,
// matching the new interface contract where Run blocks until the process is
// terminated. Returns a channel that delivers Run's eventual error.
func runAsync(t *testing.T, p *ProcessCommand) <-chan error {
t.Helper()
ch := make(chan error, 1)
go func() { ch <- p.Run(testStartTimeout) }()
ctx, cancel := context.WithTimeout(context.Background(), testStartTimeout)
defer cancel()
if err := p.WaitReady(ctx); err != nil {
t.Fatalf("WaitReady: %v", err)
}
return ch
}
func TestProcessCommand_StartStop(t *testing.T) {
skipIfNoSimpleResponder(t)
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
t.Cleanup(func() { p.Stop(testStopTimeout) })
req := httptest.NewRequest("GET", "/test", nil)
// before start: no handler
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusServiceUnavailable {
t.Errorf("before start: expected 503, got %d", rr.Code)
}
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
t.Errorf("before start: expected body to contain %q, got %q", "llama-swap-error", body)
}
runErr := runAsync(t, p)
if got := p.State(); got != StateReady {
t.Errorf("after Run: expected state %s, got %s", StateReady, got)
}
rr = httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("after Run: expected 200, got %d", rr.Code)
}
if body := rr.Body.String(); body != "hello" {
t.Errorf("expected body %q, got %q", "hello", body)
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
if got := p.State(); got != StateStopped {
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
}
select {
case err := <-runErr:
if err != nil {
t.Errorf("Run() after Stop: expected nil, got %v", err)
}
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
// after stop: handler cleared
rr = httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusServiceUnavailable {
t.Errorf("after stop: expected 503, got %d", rr.Code)
}
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
t.Errorf("after stop: expected body to contain %q, got %q", "llama-swap-error", body)
}
}
func TestProcessCommand_Run_Idempotent(t *testing.T) {
skipIfNoSimpleResponder(t)
cmd, port := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
t.Cleanup(func() { p.Stop(testStopTimeout) })
runErr := runAsync(t, p)
if err := p.Run(testStartTimeout); err == nil {
t.Error("second Run() while running: expected error, got nil")
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
}
func TestProcessCommand_Stop_Idempotent(t *testing.T) {
skipIfNoSimpleResponder(t)
cmd, port := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() before Run(): %v", err)
}
runErr := runAsync(t, p)
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("first Stop() error: %v", err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("second Stop() error: %v", err)
}
}
// TestProcessCommand_StopCancelsRun verifies that a Stop sent while Run is
// executing its health-check loop returns ErrAbort to the Run caller.
//
// A blocking mock HTTP server is used as the proxy so the test can deterministically
// know when doStart is inside the health-check loop before issuing Stop.
func TestProcessCommand_StopCancelsRun(t *testing.T) {
skipIfNoSimpleResponder(t)
healthCheckStarted := make(chan struct{}, 1)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Signal that a health check is in-flight, then block until the client
// cancels (which happens when Stop cancels the start context).
select {
case healthCheckStarted <- struct{}{}:
default:
}
<-r.Context().Done()
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
}))
defer mock.Close()
// simple-responder is the real process; health checks go to the blocking mock.
cmd, _ := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 30,
})
runErrCh := make(chan error, 1)
go func() {
runErrCh <- p.Run(testStartTimeout)
}()
// Block until doStart is actually performing a health check, guaranteeing
// that Run is in-flight when Stop is called.
<-healthCheckStarted
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
if err := <-runErrCh; !errors.Is(err, ErrStartAborted) {
t.Errorf("expected ErrStartAborted from Run, got %v", err)
}
}
// TestProcessCommand_ParentCtxCancelDuringStart verifies that cancelling the
// parent context while doStart is health-checking causes the process to
// transition to StateShutdown promptly, not wait for the health-check timeout.
//
// This is the config-reload race: Stop() returns early when parentCtx is
// already done and never writes to stopCh, so without a parentCtx.Done()
// case in the inner select, the process would keep loading indefinitely.
func TestProcessCommand_ParentCtxCancelDuringStart(t *testing.T) {
skipIfNoSimpleResponder(t)
healthCheckStarted := make(chan struct{}, 1)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case healthCheckStarted <- struct{}{}:
default:
}
<-r.Context().Done()
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
}))
defer mock.Close()
parentCtx, cancelParent := context.WithCancel(context.Background())
logger := logmon.NewWriter(io.Discard)
cmd, _ := simpleResponderCmd(t, "-silent")
p, err := New(parentCtx, t.Name(), config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 60,
}, logger, logger)
if err != nil {
t.Fatalf("New: %v", err)
}
runErrCh := make(chan error, 1)
go func() { runErrCh <- p.Run(60 * time.Second) }()
<-healthCheckStarted
// Cancel parent context to simulate a config reload tearing down the old server.
cancelParent()
select {
case err := <-runErrCh:
if !strings.Contains(err.Error(), "shutdown") {
t.Errorf("Run error = %v, want shutdown error", err)
}
case <-time.After(5 * time.Second):
t.Fatal("process did not shut down within 5s after parent context cancel during start")
}
// Run() may return before the run() goroutine writes StateShutdown;
// poll briefly to avoid a spurious race in the assertion.
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if p.State() == StateShutdown {
break
}
time.Sleep(testPollInterval)
}
if got := p.State(); got != StateShutdown {
t.Errorf("after cancel: expected StateShutdown, got %s", got)
}
}
// TestProcessCommand_RunStopCycle runs several sequential start/stop pairs on
// fresh processes to confirm they are reusable.
func TestProcessCommand_RunStopCycle(t *testing.T) {
skipIfNoSimpleResponder(t)
for i := range 3 {
cmd, port := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
runErr := runAsync(t, p)
req := httptest.NewRequest("GET", "/health", nil)
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("cycle %d: expected 200 from /health, got %d", i, rr.Code)
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("cycle %d Stop() error: %v", i, err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatalf("cycle %d: Run() did not return after Stop", i)
}
}
}
// TestProcessCommand_ReverseProxyPanicIsRecovered drives the full proxy path:
// the upstream responds healthy on /health (so Run completes), then on the
// actual proxied request it hijacks the connection and closes it mid-body.
// That upstream EOF makes httputil.ReverseProxy.copyResponse return an error,
// which panics with http.ErrAbortHandler — the wrapped handlerFn must recover
// and log the disconnect.
//
// Requests are issued through an httptest.NewServer wrapping the process so
// the panic actually fires (httputil only panics on copy errors when the
// request carries http.ServerContextKey, which a real server sets).
//
// see: https://github.com/golang/go/issues/23643
func TestProcessCommand_ReverseProxyPanicIsRecovered(t *testing.T) {
skipIfNoSimpleResponder(t)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/health" {
w.WriteHeader(http.StatusOK)
return
}
// Send a Content-Length that promises 100 bytes, deliver only a few,
// then slam the connection shut. The reverse proxy will see EOF
// before the body is fully copied and panic with ErrAbortHandler.
hj, ok := w.(http.Hijacker)
if !ok {
t.Errorf("upstream: hijack not supported")
return
}
conn, _, err := hj.Hijack()
if err != nil {
t.Errorf("upstream: hijack: %v", err)
return
}
_, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\nContent-Type: text/plain\r\n\r\npartial"))
_ = conn.Close()
}))
t.Cleanup(upstream.Close)
// Capture proxy log output so we can assert the recover message was
// emitted by handlerFn.
logBuf := &syncBuffer{}
proxyLogger := logmon.NewWriter(logBuf)
procLogger := logmon.NewWriter(io.Discard)
cmd, _ := simpleResponderCmd(t, "-silent")
p, err := New(context.Background(), t.Name(), config.ModelConfig{
Cmd: cmd,
Proxy: upstream.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
}, procLogger, proxyLogger)
if err != nil {
t.Fatalf("New: %v", err)
}
t.Cleanup(func() { p.Stop(testStopTimeout) })
_ = runAsync(t, p)
// Wrap p in an httptest server so requests get http.ServerContextKey
// automatically — that is what makes httputil.ReverseProxy raise the panic.
front := httptest.NewServer(p)
t.Cleanup(front.Close)
resp, err := http.Get(front.URL + "/disconnect")
if err == nil {
resp.Body.Close()
}
const want = "recovered from upstream disconnection"
deadline := time.Now().Add(testReturnTimeout)
for time.Now().Before(deadline) {
if strings.Contains(logBuf.String(), want) {
return
}
time.Sleep(testLogPollInterval)
}
t.Errorf("expected proxy log to contain %q; got:\n%s", want, logBuf.String())
}
// syncBuffer is a concurrent-safe bytes.Buffer for capturing logmon output.
type syncBuffer struct {
mu sync.Mutex
buf bytes.Buffer
}
func (b *syncBuffer) Write(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.Write(p)
}
func (b *syncBuffer) String() string {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.String()
}
// TestProcessCommand_TTL_StopsAfterIdle verifies that a process with a TTL
// automatically stops itself after the idle timeout has elapsed following its
// last request.
func TestProcessCommand_TTL_StopsAfterIdle(t *testing.T) {
skipIfNoSimpleResponder(t)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(mock.Close)
cmd, _ := simpleResponderCmd(t, "-silent")
cfg := config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
UnloadAfter: 1, // 1-second TTL
}
if runtime.GOOS == "windows" {
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
}
p := newProcessCommand(t, cfg)
runErr := runAsync(t, p)
defer func() {
if p.State() == StateReady {
p.Stop(testStopTimeout)
}
}()
if got := p.State(); got != StateReady {
t.Fatalf("expected StateReady, got %s", got)
}
// Make one request to prime the last-use timestamp.
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200 after request, got %d", rr.Code)
}
// Wait for the TTL goroutine to fire and the process to fully stop.
// Poll for StateStopped directly to avoid racing the StateStopping
// intermediate state that sits between StateReady and StateStopped.
deadline := time.Now().Add(5 * time.Second)
for p.State() != StateStopped && time.Now().Before(deadline) {
time.Sleep(testPollInterval)
}
if got := p.State(); got != StateStopped {
t.Fatalf("TTL did not stop process; state is %s (expected %s)", got, StateStopped)
}
// Run() should have returned nil (clean stop from TTL).
select {
case err := <-runErr:
if err != nil {
t.Errorf("Run() after TTL stop: expected nil, got %v", err)
}
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after TTL-induced stop")
}
}
// TestProcessCommand_TTL_ResetsOnRequest verifies that inflight requests
// prevent the TTL goroutine from stopping the process, and that the TTL timer
// resets after each request completes.
func TestProcessCommand_TTL_ResetsOnRequest(t *testing.T) {
skipIfNoSimpleResponder(t)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(mock.Close)
cmd, _ := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
UnloadAfter: 1, // 1-second TTL
})
runErr := runAsync(t, p)
defer func() {
if p.State() == StateReady {
p.Stop(testStopTimeout)
}
}()
// Keep sending requests for 1.5s — past the 1s TTL — and verify
// the process never stops while traffic is flowing.
stopAt := time.Now().Add(1500 * time.Millisecond)
for time.Now().Before(stopAt) {
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200, got %d", rr.Code)
}
if p.State() != StateReady {
t.Fatalf("process was stopped during active traffic (state=%s)", p.State())
}
time.Sleep(10 * time.Millisecond)
}
if got := p.State(); got != StateReady {
t.Fatalf("expected StateReady while traffic was active, got %s", got)
}
// Now stop manually to clean up.
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
}
// TestProcessCommand_TTL_ZeroDisables verifies that UnloadAfter=0 does not
// spawn a TTL goroutine — the process stays ready until explicitly stopped.
func TestProcessCommand_TTL_ZeroDisables(t *testing.T) {
skipIfNoSimpleResponder(t)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(mock.Close)
cmd, _ := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
UnloadAfter: 0, // disabled
})
runErr := runAsync(t, p)
defer func() {
if p.State() == StateReady {
p.Stop(testStopTimeout)
}
}()
if got := p.State(); got != StateReady {
t.Fatalf("expected StateReady, got %s", got)
}
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200 after request, got %d", rr.Code)
}
// No TTL goroutine is spawned when UnloadAfter=0, so a brief sleep is
// enough to confirm the process remains ready.
time.Sleep(100 * time.Millisecond)
if got := p.State(); got != StateReady {
t.Fatalf("process was stopped unexpectedly (state=%s) with TTL=0", got)
}
// Cleanly stop.
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
}
// TestProcessCommand_ConcurrentRunStop launches many concurrent run/stop racing
// pairs to exercise the race detector and verify no deadlocks occur.
func TestProcessCommand_ConcurrentRunStop(t *testing.T) {
skipIfNoSimpleResponder(t)
for range 10 {
cmd, port := simpleResponderCmd(t, "-silent")
cfg := config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
}
if runtime.GOOS == "windows" {
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
}
p := newProcessCommand(t, cfg)
runDone := make(chan struct{})
go func() {
defer close(runDone)
p.Run(testStartTimeout) //nolint: errcheck — one goroutine wins the race
}()
go func() {
p.Stop(testStopTimeout) //nolint: errcheck
}()
// Backstop: the racing Stop may have arrived before Run got on the
// channel (making it a no-op), so keep stopping until Run unblocks.
deadline := time.After(testStartTimeout)
for done := false; !done; {
select {
case <-runDone:
done = true
case <-deadline:
t.Fatal("Run did not return")
case <-time.After(testPollInterval):
p.Stop(testStopTimeout) //nolint: errcheck
}
}
}
}
+82
View File
@@ -0,0 +1,82 @@
package process
import (
"fmt"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/shared"
)
func TestProcessCommand_EmitsStateChangeEvents(t *testing.T) {
skipIfNoSimpleResponder(t)
var mu sync.Mutex
var transitions []shared.ProcessStateChangeEvent
cancel := event.On(func(e shared.ProcessStateChangeEvent) {
if e.ProcessName != t.Name() {
return
}
mu.Lock()
transitions = append(transitions, e)
mu.Unlock()
})
defer cancel()
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
runErr := runAsync(t, p)
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop: %v", err)
}
<-runErr
// Events are delivered asynchronously; give the dispatcher a moment.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
mu.Lock()
n := len(transitions)
mu.Unlock()
if n >= 4 {
break
}
time.Sleep(testPollInterval)
}
mu.Lock()
defer mu.Unlock()
for _, e := range transitions {
if e.OldState == e.NewState {
t.Errorf("emitted no-op transition: %s -> %s", e.OldState, e.NewState)
}
}
want := []string{
string(StateStopped) + "->" + string(StateStarting),
string(StateStarting) + "->" + string(StateReady),
string(StateReady) + "->" + string(StateStopping),
string(StateStopping) + "->" + string(StateStopped),
}
got := make([]string, len(transitions))
for i, e := range transitions {
got[i] = e.OldState + "->" + e.NewState
}
if len(got) != len(want) {
t.Fatalf("transitions = %v, want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("transitions = %v, want %v", got, want)
}
}
}
+44
View File
@@ -0,0 +1,44 @@
//go:build !windows
package process
import (
"os/exec"
"syscall"
)
// setProcAttributes starts the upstream in its own process group (Setpgid) so
// the entire process tree can be signalled at once via its negative PID. This
// is what lets us reap a forked grandchild — e.g. a shell wrapper that
// backgrounds the real binary and exits — instead of leaking it as an orphan
// that holds the inherited stdout/stderr pipes open.
func setProcAttributes(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}
// terminateProcessTree sends SIGTERM to the whole process group led by the
// command, giving every process in the tree a chance to shut down gracefully.
func terminateProcessTree(cmd *exec.Cmd) error {
return signalProcessTree(cmd, syscall.SIGTERM)
}
// killProcessTree sends SIGKILL to the whole process group, force-terminating
// every process in the tree.
func killProcessTree(cmd *exec.Cmd) error {
return signalProcessTree(cmd, syscall.SIGKILL)
}
// signalProcessTree signals the process group led by cmd.Process. Because the
// child was started with Setpgid it is its own group leader (pgid == pid), so
// targeting -pid reaches the child and every descendant still in the group.
// Falls back to signalling just the child if the group send fails (e.g. the
// group has already drained), so we never silently skip the signal.
func signalProcessTree(cmd *exec.Cmd, sig syscall.Signal) error {
if cmd == nil || cmd.Process == nil {
return nil
}
if err := syscall.Kill(-cmd.Process.Pid, sig); err != nil {
return cmd.Process.Signal(sig)
}
return nil
}
+53
View File
@@ -0,0 +1,53 @@
//go:build windows
package process
import (
"fmt"
"os/exec"
"syscall"
)
// setProcAttributes sets platform-specific process attributes. CREATE_NO_WINDOW
// keeps the upstream from spawning its own console window.
func setProcAttributes(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
HideWindow: true,
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
}
}
// terminateProcessTree requests a graceful shutdown of the whole process tree
// rooted at cmd.Process. Windows has no SIGTERM or process-group signalling, so
// we shell out to `taskkill /t`, which walks the child tree by PID — the
// equivalent of signalling a Unix process group. Without /f, taskkill asks the
// processes to close rather than force-killing them.
func terminateProcessTree(cmd *exec.Cmd) error {
return taskkillProcessTree(cmd, false)
}
// killProcessTree force-terminates the whole process tree rooted at cmd.Process
// via `taskkill /f /t`, so any descendant that ignored or outlived the graceful
// request is killed alongside the parent rather than leaked as an orphan.
func killProcessTree(cmd *exec.Cmd) error {
return taskkillProcessTree(cmd, true)
}
// taskkillProcessTree runs taskkill against cmd.Process.Pid. The /t flag
// terminates the process together with any child processes it started, which is
// the Windows analogue of signalling a Unix process group via its negative PID.
// When force is true the /f flag force-kills; otherwise taskkill requests a
// graceful close.
func taskkillProcessTree(cmd *exec.Cmd, force bool) error {
if cmd == nil || cmd.Process == nil {
return nil
}
args := make([]string, 0, 4)
if force {
args = append(args, "/f")
}
args = append(args, "/t", "/pid", fmt.Sprintf("%d", cmd.Process.Pid))
kill := exec.Command("taskkill", args...)
setProcAttributes(kill)
return kill.Run()
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !windows
package process
// SetupTreeCleanup is a no-op on non-Windows platforms, where upstream process
// teardown is handled via process-group signalling (see runtime_unix.go).
func SetupTreeCleanup() error { return nil }
+50
View File
@@ -0,0 +1,50 @@
//go:build windows
package process
import (
"fmt"
"unsafe"
"golang.org/x/sys/windows"
)
// SetupTreeCleanup assigns the current process to a Windows Job Object
// configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. Upstream processes
// spawned afterwards are associated with the same job, so when llama-swap exits
// for any reason — graceful shutdown, a forced second Ctrl+C, or a crash — the
// OS terminates the whole job and reaps every child instead of leaving orphans
// behind. It is the parent-side complement to the per-process teardown in
// runtime_windows.go.
//
// The job handle is intentionally leaked for the lifetime of the process: the
// kill-on-close behaviour fires when the last handle is released, which the OS
// does when the process exits.
func SetupTreeCleanup() error {
job, err := windows.CreateJobObject(nil, nil)
if err != nil {
return fmt.Errorf("CreateJobObject: %w", err)
}
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{
BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{
LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
},
}
if _, err := windows.SetInformationJobObject(
job,
windows.JobObjectExtendedLimitInformation,
uintptr(unsafe.Pointer(&info)),
uint32(unsafe.Sizeof(info)),
); err != nil {
windows.CloseHandle(job)
return fmt.Errorf("SetInformationJobObject: %w", err)
}
if err := windows.AssignProcessToJobObject(job, windows.CurrentProcess()); err != nil {
windows.CloseHandle(job)
return fmt.Errorf("AssignProcessToJobObject: %w", err)
}
return nil
}
+39
View File
@@ -0,0 +1,39 @@
package ring
type Buffer[T any] struct {
buf []T
head int
size int
}
func NewBuffer[T any](capacity int) Buffer[T] {
if capacity < 1 {
capacity = 1
}
return Buffer[T]{buf: make([]T, capacity)}
}
// Push adds v, overwriting the oldest entry when the buffer is full.
func (r *Buffer[T]) Push(v T) {
cap := len(r.buf)
if r.size < cap {
r.buf[(r.head+r.size)%cap] = v
r.size++
} else {
r.buf[r.head] = v
r.head = (r.head + 1) % cap
}
}
// Slice returns all entries in insertion order as a new slice.
func (r *Buffer[T]) Slice() []T {
if r.size == 0 {
return nil
}
cap := len(r.buf)
result := make([]T, r.size)
for i := 0; i < r.size; i++ {
result[i] = r.buf[(r.head+i)%cap]
}
return result
}
+44
View File
@@ -0,0 +1,44 @@
package ring
import "testing"
const benchCap = 600 // matches default MaxAge/Every (1min / 100ms)
func BenchmarkBuffer_PushNoWrap(b *testing.B) {
for b.Loop() {
buf := NewBuffer[int](b.N + 1)
for i := range b.N {
buf.Push(i)
}
}
}
func BenchmarkBuffer_PushWrap(b *testing.B) {
buf := NewBuffer[int](benchCap)
b.ResetTimer()
for i := range b.N {
buf.Push(i)
}
}
func BenchmarkBuffer_Slice(b *testing.B) {
buf := NewBuffer[int](benchCap)
for i := range benchCap {
buf.Push(i)
}
b.ResetTimer()
for range b.N {
_ = buf.Slice()
}
}
func BenchmarkBuffer_PushAndSlice(b *testing.B) {
buf := NewBuffer[int](benchCap)
b.ResetTimer()
for i := range b.N {
buf.Push(i)
if i%benchCap == 0 {
_ = buf.Slice()
}
}
}
+65
View File
@@ -0,0 +1,65 @@
package ring
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBuffer_EmptySliceIsNil(t *testing.T) {
b := NewBuffer[int](4)
assert.Nil(t, b.Slice())
}
func TestBuffer_PushBelowCapacity(t *testing.T) {
b := NewBuffer[int](4)
b.Push(1)
b.Push(2)
assert.Equal(t, []int{1, 2}, b.Slice())
}
func TestBuffer_PushAtCapacity(t *testing.T) {
b := NewBuffer[int](3)
b.Push(1)
b.Push(2)
b.Push(3)
assert.Equal(t, []int{1, 2, 3}, b.Slice())
}
func TestBuffer_PushOverCapacityEvictsOldest(t *testing.T) {
b := NewBuffer[int](3)
b.Push(1)
b.Push(2)
b.Push(3)
b.Push(4)
assert.Equal(t, []int{2, 3, 4}, b.Slice())
}
func TestBuffer_CapacityOne(t *testing.T) {
b := NewBuffer[int](1)
b.Push(1)
b.Push(2)
assert.Equal(t, []int{2}, b.Slice())
}
func TestBuffer_ZeroCapacityDefaultsToOne(t *testing.T) {
b := NewBuffer[int](0)
b.Push(42)
assert.Equal(t, []int{42}, b.Slice())
}
func TestBuffer_SliceReturnsCopy(t *testing.T) {
b := NewBuffer[int](4)
b.Push(10)
s := b.Slice()
s[0] = 99
assert.Equal(t, []int{10}, b.Slice())
}
func TestBuffer_InsertionOrderPreservedAfterWrap(t *testing.T) {
b := NewBuffer[int](4)
for i := 1; i <= 8; i++ {
b.Push(i)
}
assert.Equal(t, []int{5, 6, 7, 8}, b.Slice())
}
+800
View File
@@ -0,0 +1,800 @@
package router
import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type shutdownReq struct {
timeout time.Duration
respond chan error
}
type unloadReq struct {
targets []string
timeout time.Duration
respond chan struct{}
}
type handlerReq struct {
model string
ctx context.Context
respond chan handlerResp
positionCh chan int
}
type handlerResp struct {
handleFunc http.HandlerFunc
err error
}
type swapDone struct {
modelID string
err error
}
type serveDoneEvent struct {
modelID string
}
type activeSwap struct {
modelID string
evict []string
waiters []handlerReq
}
// swapPlanner is the only piece of behaviour that differs between concrete
// routers. baseRouter never inspects its internals.
type swapPlanner interface {
// EvictionFor returns running model IDs that must be stopped before
// target can serve. alsoRunning lists models the baseRouter has already
// committed to loading (in-flight swaps) which the planner cannot see
// via process.State() yet. Pure decision; must not log.
EvictionFor(target string, alsoRunning []string) []string
// OnSwapStart runs once at the start of every swap. Planners may log
// their decision here at whatever verbosity they choose.
OnSwapStart(target string)
}
// baseRouter owns the channels, run-loop, and orchestration code shared by
// every concrete router. Concrete routers embed *baseRouter and supply a
// swapPlanner that captures how their eviction set is decided.
type baseRouter struct {
name string
config config.Config
processes map[string]process.Process
logger *logmon.Monitor
planner swapPlanner
// shutdownCtx governs the request machinery: cancelling it tells grant()
// and ServeHTTP to stop granting and reject callers. It is deliberately
// separate from procCtx — see procCtx below.
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
// procCtx is the parent context for every managed process and governs
// process lifetime only. handleShutdown stops processes gracefully via
// Stop() and cancels procCtx afterwards, so teardown is never a context
// cancel racing the graceful path (which collapsed the grace to 100ms and
// let the caller return before children were reaped — see process run loop).
procCtx context.Context
procCancel context.CancelFunc
handlerCh chan handlerReq
shutdownCh chan shutdownReq
unloadCh chan unloadReq
swapDoneCh chan swapDone
serveDoneCh chan serveDoneEvent
runDone chan struct{}
// testProcessed, when non-nil, receives one event after each handlerReq
// or swapDone has been fully processed by run(). Tests use it to wait
// for run() to reach a deterministic state without sleeping. serveDone
// events are intentionally NOT signalled here so test event counts
// remain stable.
testProcessed chan struct{}
}
func newBaseRouter(name string, conf config.Config, processes map[string]process.Process, planner swapPlanner, logger *logmon.Monitor) *baseRouter {
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
procCtx, procCancel := context.WithCancel(context.Background())
return &baseRouter{
name: name,
config: conf,
processes: processes,
logger: logger,
planner: planner,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
procCtx: procCtx,
procCancel: procCancel,
handlerCh: make(chan handlerReq),
shutdownCh: make(chan shutdownReq),
unloadCh: make(chan unloadReq),
swapDoneCh: make(chan swapDone),
serveDoneCh: make(chan serveDoneEvent),
runDone: make(chan struct{}),
}
}
func (b *baseRouter) notifyProcessed() {
if b.testProcessed != nil {
b.testProcessed <- struct{}{}
}
}
func (b *baseRouter) run() {
defer close(b.runDone)
active := make(map[string]*activeSwap)
inFlight := make(map[string]int)
var queued []handlerReq
for {
select {
case req := <-b.shutdownCh:
b.handleShutdown(req, active, queued)
return
case req := <-b.handlerCh:
b.handleRequest(req, active, inFlight, &queued)
b.notifyProcessed()
case req := <-b.unloadCh:
b.handleUnload(req, active, inFlight, &queued)
b.notifyProcessed()
case ev := <-b.swapDoneCh:
b.handleSwapDone(ev, active, inFlight, &queued)
b.notifyProcessed()
case ev := <-b.serveDoneCh:
b.handleServeDone(ev, active, inFlight, &queued)
}
}
}
// grant sends a response back to the caller of ServeHTTP and tells us
// whether the caller was still there to receive it.
//
// Each ServeHTTP creates a fresh, UNBUFFERED respond channel and parks in
// a select waiting on it. "Unbuffered" is the important word: a send only
// completes when the other side is actively receiving. So if this send
// succeeds, we know for a fact the caller picked up the response and will
// act on it. If the caller has already given up (its request context was
// cancelled, e.g. the HTTP client disconnected) or the router is shutting
// down, the send never lands, one of the other select cases fires, and we
// report back that the grant did NOT happen.
//
// That distinction matters for in-flight bookkeeping — see grantHandler.
func (b *baseRouter) grant(req handlerReq, resp handlerResp) bool {
select {
case req.respond <- resp:
return true
case <-req.ctx.Done():
return false
case <-b.shutdownCtx.Done():
return false
}
}
// grantHandler is the "this caller can now use process p" path. It does
// two things that must stay locked together:
//
// 1. Hand the caller a wrapped p.ServeHTTP (via trackedServe) so when the
// HTTP request finishes, the run loop hears about it.
// 2. Bump inFlight[modelID] so the router knows this process is busy and
// refuses to evict it until the count comes back down.
//
// The increment is gated on grant() returning true. If grant() returns
// false, the caller already walked away and trackedServe will never run —
// which means no matching decrement will ever arrive on serveDoneCh.
// Incrementing in that case would strand the counter at >0 forever and
// the router would never again be willing to swap this model out.
//
// In short: increment if and only if we know a decrement is coming.
func (b *baseRouter) grantHandler(req handlerReq, modelID string, p process.Process, inFlight map[string]int) {
if b.grant(req, handlerResp{handleFunc: b.trackedServe(modelID, p)}) {
inFlight[modelID]++
}
}
// trackedServe is the wrapper that closes the loop on in-flight tracking.
// It runs p.ServeHTTP normally; the only added behaviour is a deferred
// send on serveDoneCh after the handler returns. That send is what tells
// the run loop "this model now has one fewer request in flight — go look
// at the queue again, you may be able to start a swap you previously had
// to defer."
//
// The select on shutdownCtx.Done() is a release valve: if the router is
// already shutting down, nobody is reading serveDoneCh, so we drop the
// notification rather than blocking the HTTP goroutine forever.
func (b *baseRouter) trackedServe(modelID string, p process.Process) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
select {
case b.serveDoneCh <- serveDoneEvent{modelID: modelID}:
case <-b.shutdownCtx.Done():
}
}()
p.ServeHTTP(w, r)
}
}
// handleRequest decides what to do with one incoming ServeHTTP request. It is
// called from run() and never blocks indefinitely: any work that has to wait
// (starting a process, stopping siblings, waiting for ready) is deferred to
// a swap goroutine and reported back via swapDoneCh.
//
// The decision tree, in order:
//
// 1. Unknown model — respond with ErrNoLocalModelFound and move on.
// 2. A swap to the same model is already in flight — attach this waiter so
// one swap serves all callers that asked for the same model.
// 3. Fast path — the target process is already ready, the planner sees
// nothing to evict, and no in-flight swap is evicting it. Hand back its
// ServeHTTP immediately (wrapped so the run loop knows when it ends).
// 4. Would collide with an in-flight swap (we'd stop their target, or
// they're stopping us) — park in the queue for handleSwapDone to drain.
// 5. Would evict a process that is still handling requests — park in the
// queue. handleServeDone will retry when the busy process drains.
// 6. Otherwise — start a new swap. This may run in parallel with other
// active swaps when their evict sets don't intersect.
func (b *baseRouter) handleRequest(req handlerReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
// (1) Unknown model.
p, ok := b.processes[req.model]
if !ok {
b.logger.Debugf("%s: model %s not handled by this router", b.name, req.model)
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
return
}
// (2) Join an in-flight swap for the same model.
if s, ok := active[req.model]; ok {
b.logger.Debugf("%s: joining in-flight swap for model %s (%d waiters)", b.name, req.model, len(s.waiters)+1)
s.waiters = append(s.waiters, req)
return
}
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
// (3) Fast path: ready, nothing to evict, and nobody is evicting us.
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: fast-path serving model %s (already ready)", b.name, req.model)
b.grantHandler(req, req.model, p, inFlight)
return
}
// (4) Collision with an in-flight swap — queue.
if collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: queuing request for model %s (collides with in-flight swap)", b.name, req.model)
*queued = append(*queued, req)
b.broadcastQueuePositions(*queued)
return
}
// (5) Would evict a busy process — queue until it drains.
if conflictsWithInFlight(evict, inFlight) {
b.logger.Debugf("%s: queuing request for model %s (would evict in-flight process)", b.name, req.model)
*queued = append(*queued, req)
b.broadcastQueuePositions(*queued)
return
}
// (6) Start a new (possibly parallel) swap.
b.logger.Debugf("%s: starting swap for model %s, evicting %v", b.name, req.model, evict)
s := b.startSwap(req, evict)
active[s.modelID] = s
}
// handleSwapDone is called from run() when a swap goroutine reports that it
// has finished. It fans out the result to every waiter that joined this swap,
// removes the swap from the active map, and then walks the queue once,
// promoting any items that no longer collide with the remaining active set.
// FIFO order is preserved: items still blocked stay in place.
func (b *baseRouter) handleSwapDone(ev swapDone, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
s, ok := active[ev.modelID]
if !ok {
return
}
delete(active, ev.modelID)
for _, w := range s.waiters {
if ev.err != nil {
b.grant(w, handlerResp{err: ev.err})
} else {
p := b.processes[ev.modelID]
b.grantHandler(w, ev.modelID, p, inFlight)
}
}
b.drainQueue(active, inFlight, queued)
}
// handleServeDone is called from run() each time a tracked ServeHTTP
// finishes. It decrements the per-model in-flight count and, when that
// drops to zero, retries the queue: requests whose swap was deferred
// because they would have evicted this (now-idle) process can now proceed.
func (b *baseRouter) handleServeDone(ev serveDoneEvent, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
inFlight[ev.modelID]--
if inFlight[ev.modelID] <= 0 {
delete(inFlight, ev.modelID)
b.drainQueue(active, inFlight, queued)
}
}
// drainQueue walks the queued requests in order, re-running the handleRequest
// decision tree against the (now smaller) active set. Items that can now start
// or join become satisfied; items still blocked remain queued in original
// order so they get another chance on the next swap completion.
func (b *baseRouter) drainQueue(active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
if len(*queued) == 0 {
return
}
pending := *queued
var remaining []handlerReq
for _, req := range pending {
p, ok := b.processes[req.model]
if !ok {
b.grant(req, handlerResp{err: ErrNoLocalModelFound})
continue
}
if s, ok := active[req.model]; ok {
b.logger.Debugf("%s: queued request for model %s now joining in-flight swap", b.name, req.model)
s.waiters = append(s.waiters, req)
continue
}
evict := b.planner.EvictionFor(req.model, activeTargets(active, req.model))
if p.State() == process.StateReady && len(evict) == 0 && !collidesWith(req.model, evict, active) {
b.logger.Debugf("%s: queued request for model %s now served fast-path", b.name, req.model)
b.grantHandler(req, req.model, p, inFlight)
continue
}
if collidesWith(req.model, evict, active) {
remaining = append(remaining, req)
continue
}
if conflictsWithInFlight(evict, inFlight) {
remaining = append(remaining, req)
continue
}
b.logger.Debugf("%s: queued request for model %s now starting swap, evicting %v", b.name, req.model, evict)
s := b.startSwap(req, evict)
active[s.modelID] = s
}
*queued = remaining
b.broadcastQueuePositions(*queued)
}
// broadcastQueuePositions sends each queued request its current 1-indexed
// position. Sends are non-blocking: if the channel is full, the old value is
// drained first so the consumer always sees the latest position.
func (b *baseRouter) broadcastQueuePositions(queued []handlerReq) {
for i, req := range queued {
pos := i + 1
select {
case req.positionCh <- pos:
default:
select {
case <-req.positionCh:
default:
}
select {
case req.positionCh <- pos:
default:
}
}
}
}
func (b *baseRouter) startSwap(initial handlerReq, evict []string) *activeSwap {
swap := &activeSwap{
modelID: initial.model,
evict: evict,
waiters: []handlerReq{initial},
}
b.planner.OnSwapStart(initial.model)
go b.doSwap(initial.model, evict)
return swap
}
// activeTargets returns the IDs of every in-flight swap target except exclude.
// baseRouter passes this to the planner so eviction decisions account for
// models that have been committed to but have not yet transitioned to
// StateStarting in their process state machine.
func activeTargets(active map[string]*activeSwap, exclude string) []string {
if len(active) == 0 {
return nil
}
out := make([]string, 0, len(active))
for id := range active {
if id == exclude {
continue
}
out = append(out, id)
}
return out
}
// collidesWith reports whether a new swap with this target and evict set can
// safely run alongside the currently active swaps. Same-target callers should
// JOIN (handled before this) — they do not collide with themselves.
func collidesWith(target string, evict []string, active map[string]*activeSwap) bool {
for id, s := range active {
if id == target {
continue
}
if containsString(evict, id) {
return true
}
if containsString(s.evict, target) {
return true
}
}
return false
}
// conflictsWithInFlight reports whether any model in evict is still handling
// requests. Stopping a busy process would cancel its callers' connections,
// so the router defers the swap until those callers finish.
func conflictsWithInFlight(evict []string, inFlight map[string]int) bool {
for _, m := range evict {
if inFlight[m] > 0 {
return true
}
}
return false
}
func containsString(xs []string, s string) bool {
for _, x := range xs {
if x == s {
return true
}
}
return false
}
func (b *baseRouter) doSwap(modelID string, toStop []string) {
timeout := b.healthCheckTimeout()
var wg sync.WaitGroup
for _, mID := range toStop {
wg.Add(1)
go func(p process.Process, id string) {
defer wg.Done()
if err := p.Stop(timeout); err != nil {
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
}
}(b.processes[mID], mID)
}
wg.Wait()
target := b.processes[modelID]
if target.State() == process.StateStopped {
go func() {
if err := target.Run(timeout); err != nil {
b.logger.Warnf("%s: running %s exited: %v", b.name, modelID, err)
}
}()
}
err := target.WaitReady(b.shutdownCtx)
select {
case b.swapDoneCh <- swapDone{modelID: modelID, err: err}:
case <-b.shutdownCtx.Done():
}
}
func (b *baseRouter) handleShutdown(req shutdownReq, active map[string]*activeSwap, queued []handlerReq) {
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
// Cancel shutdownCtx first so any waiter that is currently parked on
// its respond channel can exit via its own shutdownCtx.Done() branch.
// The grant calls below then either land (waiter happened to receive
// before noticing shutdown) or fall through immediately via grant's
// shutdownCtx case — either way the waiter sees a non-OK response.
// This does NOT touch processes: their lifetime is procCtx, cancelled
// only after the graceful Stop() calls below have reaped them.
b.shutdownFn()
for _, s := range active {
for _, w := range s.waiters {
b.grant(w, handlerResp{err: shutdownErr})
}
}
for _, w := range queued {
b.grant(w, handlerResp{err: shutdownErr})
}
stopTimeout := req.timeout
if stopTimeout <= 0 {
stopTimeout = b.healthCheckTimeout()
}
var wg sync.WaitGroup
for i, p := range b.processes {
wg.Add(1)
go func(id string, p process.Process) {
defer wg.Done()
if err := p.Stop(stopTimeout); err != nil {
b.logger.Warnf("%s failed to stop process %s: %v", b.name, id, err)
}
}(i, p)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
if req.timeout > 0 {
select {
case <-done:
case <-time.After(req.timeout):
<-done
}
} else {
<-done
}
// Every process is stopped (children reaped via Stop()). Cancel procCtx so
// the process run-loop goroutines exit; they are already StateStopped, so
// this is a clean no-op kill rather than a forced teardown.
b.procCancel()
req.respond <- nil
}
func (b *baseRouter) healthCheckTimeout() time.Duration {
t := time.Duration(b.config.HealthCheckTimeout) * time.Second
if t <= 0 {
return 30 * time.Second
}
return t
}
func (b *baseRouter) Handles(model string) bool {
_, ok := b.processes[model]
return ok
}
func (b *baseRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
if p, ok := b.processes[modelID]; ok {
return p.Logger(), true
}
return nil, false
}
// RunningModels returns the current state of every process that is not stopped
// or shut down. The processes map keys are fixed at construction and State()
// is a snapshot, so this is safe to call without the run loop.
func (b *baseRouter) RunningModels() map[string]process.ProcessState {
running := make(map[string]process.ProcessState)
for id, p := range b.processes {
st := p.State()
if st == process.StateStopped || st == process.StateShutdown {
continue
}
running[id] = st
}
return running
}
// Unload stops the named models, or every running model when none are named.
// It blocks until each targeted process has stopped.
//
// The request is funneled through the run loop so eviction is coordinated
// with the rest of the router's state: pending swap waiters for an
// unloaded model are released with an error, queued requests for unloaded
// models are dropped, and any deferred swaps that were waiting on those
// models become eligible to start.
//
// In-flight requests being served by an unloaded process are not waited
// for — Stop kills the upstream, those callers see whatever error the
// reverse proxy surfaces and may retry. Their trackedServe defers fire
// normally and decrement inFlight as the dying handlers return.
func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
targets := models
if len(targets) == 0 {
targets = make([]string, 0, len(b.processes))
for id := range b.processes {
targets = append(targets, id)
}
}
if len(targets) == 0 {
return
}
req := unloadReq{targets: targets, timeout: timeout, respond: make(chan struct{})}
select {
case b.unloadCh <- req:
case <-b.runDone:
return
}
<-req.respond
}
// handleUnload runs on the run loop in response to an Unload call. It
// reconciles router-owned state with the impending Stop, then performs
// the Stop synchronously so callers of Unload remain blocked until each
// targeted process has actually exited.
func (b *baseRouter) handleUnload(req unloadReq, active map[string]*activeSwap, inFlight map[string]int, queued *[]handlerReq) {
unloadErr := fmt.Errorf("%s: model unloaded", b.name)
targetSet := make(map[string]bool, len(req.targets))
for _, id := range req.targets {
targetSet[id] = true
}
// Release waiters of any in-flight swap whose target is being
// unloaded. The swap goroutine itself is left to finish on its own;
// when its swapDone arrives, handleSwapDone will find no entry in
// active and silently drop it.
for id := range targetSet {
s, ok := active[id]
if !ok {
continue
}
for _, w := range s.waiters {
b.grant(w, handlerResp{err: unloadErr})
}
delete(active, id)
}
// Drop queued requests addressed to unloaded models. Requests for
// other models stay queued and may benefit from drainQueue at the end.
if len(*queued) > 0 {
kept := (*queued)[:0]
for _, w := range *queued {
if targetSet[w.model] {
b.grant(w, handlerResp{err: unloadErr})
continue
}
kept = append(kept, w)
}
*queued = kept
}
// Stop the targeted processes. Done synchronously so Unload's caller
// can rely on "after Unload returns, the process is stopped". inFlight
// is intentionally NOT cleared here: each dying handler will fire its
// trackedServe defer and reach handleServeDone in the normal way once
// the run loop is free again.
var wg sync.WaitGroup
for id := range targetSet {
p, ok := b.processes[id]
if !ok {
continue
}
wg.Add(1)
go func(id string, p process.Process) {
defer wg.Done()
if err := p.Stop(req.timeout); err != nil {
b.logger.Warnf("%s: unloading %s failed: %v", b.name, id, err)
}
}(id, p)
}
wg.Wait()
// Removing entries from active above may have unblocked queued
// requests that previously collided with the now-cancelled swaps.
b.drainQueue(active, inFlight, queued)
close(req.respond)
}
func (b *baseRouter) Shutdown(timeout time.Duration) error {
if !b.shuttingDown.CompareAndSwap(false, true) {
return fmt.Errorf("%s shutdown already in progress", b.name)
}
req := shutdownReq{timeout: timeout, respond: make(chan error, 1)}
select {
case b.shutdownCh <- req:
case <-b.runDone:
return nil
}
return <-req.respond
}
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if b.shuttingDown.Load() {
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
data, err := FetchContext(req, b.config)
if err != nil {
SendError(w, req, err)
return
}
hr := handlerReq{
model: data.ModelID,
ctx: req.Context(),
// Unbuffered: a successful send on respond proves the waiter is
// alive and consuming. grant() relies on this to avoid handing a
// handleFunc to a cancelled waiter and leaking the inFlight count.
respond: make(chan handlerResp),
positionCh: make(chan int, 1),
}
select {
case b.handlerCh <- hr:
case <-req.Context().Done():
return
case <-b.shutdownCtx.Done():
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
isModelReady := false
if p, ok := b.processes[data.ModelID]; ok {
isModelReady = p.State() == process.StateReady
}
shouldShowLoading := data.Streaming && data.SendLoadingState && isLoadingPath(req.URL.Path) && !isModelReady
var lw *loadingWriter
cancelLoad := func() {}
if shouldShowLoading {
var swapCtx context.Context
swapCtx, cancelLoad = context.WithCancel(req.Context())
lw = newLoadingWriter(b.logger, data.ModelID, w, req)
go lw.start(swapCtx)
go func() {
for {
select {
case pos := <-hr.positionCh:
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
case <-swapCtx.Done():
return
}
}
}()
}
// finishLoading stops the loading stream and fences its goroutine off from
// the ResponseWriter before the real handler (or ServeHTTP's return)
// reclaims it. release() must run even when waitForCompletion times out:
// otherwise a still-streaming goroutine flushes a finalized response and
// panics on the recycled *bufio.Writer.
finishLoading := func() {
cancelLoad()
if lw != nil {
lw.waitForCompletion(1 * time.Second)
lw.release()
}
}
var resp handlerResp
select {
case resp = <-hr.respond:
finishLoading()
case <-req.Context().Done():
finishLoading()
return
case <-b.shutdownCtx.Done():
finishLoading()
SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
if resp.err != nil {
SendError(w, req, resp.err)
return
}
resp.handleFunc(w, req)
}
+863
View File
@@ -0,0 +1,863 @@
package router
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// stubPlanner is a swapPlanner that returns a fixed eviction list per target
// and never logs. It lets the base-router tests cover shared run-loop
// behaviour without dragging in either real router's eviction rules.
type stubPlanner struct {
evict map[string][]string
}
func (s *stubPlanner) EvictionFor(target string, _ []string) []string {
if s.evict == nil {
return nil
}
return s.evict[target]
}
func (s *stubPlanner) OnSwapStart(string) {}
func newTestBase(t *testing.T, processes map[string]process.Process, planner swapPlanner) *baseRouter {
t.Helper()
conf := config.Config{HealthCheckTimeout: 5}
b := newBaseRouter("test", conf, processes, planner, logmon.NewWriter(io.Discard))
b.testProcessed = make(chan struct{}, 64)
go b.run()
t.Cleanup(func() {
if !b.shuttingDown.Load() {
_ = b.Shutdown(time.Second)
}
})
return b
}
func TestBaseRouter_RunningModels(t *testing.T) {
ready := newFakeProcess("ready")
ready.markReady()
starting := newFakeProcess("starting")
starting.setState(process.StateStarting)
stopped := newFakeProcess("stopped")
b := newTestBase(t, map[string]process.Process{
"ready": ready, "starting": starting, "stopped": stopped,
}, &stubPlanner{})
running := b.RunningModels()
if len(running) != 2 {
t.Fatalf("running=%v want 2 entries", running)
}
if running["ready"] != process.StateReady {
t.Errorf("ready state=%q want ready", running["ready"])
}
if running["starting"] != process.StateStarting {
t.Errorf("starting state=%q want starting", running["starting"])
}
if _, ok := running["stopped"]; ok {
t.Errorf("stopped process should be excluded from RunningModels")
}
}
func TestBaseRouter_UnloadAll(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
c := newFakeProcess("c")
c.markReady()
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
b.Unload(time.Second)
if a.State() != process.StateStopped || c.State() != process.StateStopped {
t.Fatalf("Unload() should stop every process: a=%q c=%q", a.State(), c.State())
}
}
func TestBaseRouter_UnloadSpecificModel(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
c := newFakeProcess("c")
c.markReady()
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
b.Unload(time.Second, "a")
if a.State() != process.StateStopped {
t.Errorf("a should be stopped, got %q", a.State())
}
if c.State() != process.StateReady {
t.Errorf("c should remain ready, got %q", c.State())
}
}
// TestBaseRouter_Unload_StopsInParallel verifies that Unload fans out its
// Stop calls concurrently rather than stopping each process serially. Each
// fakeProcess.Stop is pinned via stopBlock; the test only releases them
// after observing every stopStarted, proving all three Stops were in
// flight simultaneously.
func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
a.stopBlock = make(chan struct{})
pb := newFakeProcess("b")
pb.markReady()
pb.stopBlock = make(chan struct{})
pc := newFakeProcess("c")
pc.markReady()
pc.stopBlock = make(chan struct{})
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, &stubPlanner{})
unloadDone := make(chan struct{})
go func() {
b.Unload(time.Second, "a", "b", "c")
close(unloadDone)
}()
// All three Stop calls must start before any of them are allowed to
// complete. If Unload was serial, only one stopStarted would fire
// until we released its stopBlock, and this would deadlock.
for _, p := range []*fakeProcess{a, pb, pc} {
select {
case <-p.stopStarted:
case <-time.After(2 * time.Second):
t.Fatalf("Stop on %s never started — Unload is not parallel", p.id)
}
}
// Release them; Unload should now return.
close(a.stopBlock)
close(pb.stopBlock)
close(pc.stopBlock)
select {
case <-unloadDone:
case <-time.After(2 * time.Second):
t.Fatal("Unload did not return after stops released")
}
for _, p := range []*fakeProcess{a, pb, pc} {
if p.State() != process.StateStopped {
t.Errorf("%s state=%q want stopped", p.id, p.State())
}
if got := p.stopCalls.Load(); got != 1 {
t.Errorf("%s stopCalls=%d want 1", p.id, got)
}
}
}
// TestBaseRouter_Unload_ReleasesActiveSwapWaiters verifies that Unload
// rejoins router state: a request whose swap to the unloaded model is
// still in progress receives an error, instead of being abandoned
// against a process that's about to vanish.
func TestBaseRouter_Unload_ReleasesActiveSwapWaiters(t *testing.T) {
a := newFakeProcess("a")
// autoReady=false: the swap parks on WaitReady so we can interrupt
// it with Unload before it completes.
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
done := make(chan struct{})
go func() {
b.ServeHTTP(w, newRequest("a"))
close(done)
}()
waitProcessed(t, b.testProcessed, 1) // handlerReq absorbed; swap started
<-a.runStarted
b.Unload(time.Second, "a")
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("ServeHTTP did not return after Unload")
}
if w.Code == http.StatusOK {
t.Errorf("expected non-OK status after Unload, got %d body=%q", w.Code, w.Body.String())
}
if a.State() != process.StateStopped {
t.Errorf("a state=%q want stopped", a.State())
}
}
// TestBaseRouter_Unload_DropsQueuedRequests verifies that queued requests
// for an unloaded model receive an error rather than sitting forever in
// the queue against state the router no longer maintains.
func TestBaseRouter_Unload_DropsQueuedRequests(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
// Loading B evicts A — so a request for B while A is loading queues.
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
// r1 starts the swap to A and parks on WaitReady (autoReady=false).
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1)
<-a.runStarted
// r2 for B collides with A's in-flight swap and queues.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
// Unload B — r2 (queued, targeting B) must be released with an error.
b.Unload(time.Second, "b")
select {
case <-done2:
case <-time.After(2 * time.Second):
t.Fatal("queued B request did not return after Unload(b)")
}
if w2.Code == http.StatusOK {
t.Errorf("queued B request: expected non-OK status, got %d", w2.Code)
}
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b.runCalls=%d want 0 (B should never have been started)", got)
}
// Release r1 so the test cleans up cleanly.
a.markReady()
select {
case <-done1:
case <-time.After(2 * time.Second):
t.Fatal("r1 did not complete after a.markReady")
}
}
func TestBaseRouter_FastPath(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.serveCalls.Load(); got != 1 {
t.Errorf("serveCalls=%d want 1", got)
}
if got := a.runCalls.Load(); got != 0 {
t.Errorf("runCalls=%d want 0 (fast path should not start)", got)
}
}
func TestBaseRouter_OnDemandStart(t *testing.T) {
a := newFakeProcess("a")
a.autoReady = true
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.runCalls.Load(); got != 1 {
t.Errorf("runCalls=%d want 1", got)
}
if got := a.serveCalls.Load(); got != 1 {
t.Errorf("serveCalls=%d want 1", got)
}
}
func TestBaseRouter_ConcurrentSameModel(t *testing.T) {
a := newFakeProcess("a")
// autoReady=false so the swap parks on WaitReady until we release it.
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
const N = 5
var wg sync.WaitGroup
codes := make([]int, N)
for i := 0; i < N; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
codes[i] = w.Code
}(i)
}
waitProcessed(t, b.testProcessed, N) // all N handlerReqs absorbed by run()
<-a.runStarted // swap goroutine reached Run()
a.markReady()
wg.Wait()
for i, c := range codes {
if c != http.StatusOK {
t.Errorf("request %d: status=%d", i, c)
}
}
if got := a.runCalls.Load(); got != 1 {
t.Errorf("runCalls=%d want 1 (single swap should issue one Run)", got)
}
if got := a.serveCalls.Load(); got != N {
t.Errorf("serveCalls=%d want %d", got, N)
}
}
func TestBaseRouter_ContextCancel(t *testing.T) {
a := newFakeProcess("a")
// autoReady=false so swap parks forever until we mark ready.
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
ctx, cancel := context.WithCancel(context.Background())
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequestCtx(ctx, "a"))
close(done1)
}()
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("a"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 2) // both requests joined the active swap
<-a.runStarted
cancel()
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("cancelled ServeHTTP did not return after ctx cancel")
}
a.markReady()
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("non-cancelled ServeHTTP did not complete after swap")
}
if w2.Code != http.StatusOK {
t.Errorf("second request status=%d body=%q", w2.Code, w2.Body.String())
}
}
func TestBaseRouter_QueuedDifferentModel(t *testing.T) {
a := newFakeProcess("a")
pa := newFakeProcess("b")
// Loading b must stop a.
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pa}, planner)
// First request starts a swap to A; A's autoReady=false so it parks.
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1)
<-a.runStarted
// Second request for B should queue while A's swap is in flight.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
if got := pa.runCalls.Load(); got != 0 {
t.Errorf("b started early: runCalls=%d want 0 while A's swap is pending", got)
}
// Release A's swap. B's swap should then run.
a.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone for A → B's swap kicked off
<-pa.runStarted
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("A request did not complete")
}
pa.markReady()
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("queued B request did not complete after A's swap")
}
if w2.Code != http.StatusOK {
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (B's swap must stop A)", got)
}
}
// TestBaseRouter_QueueCollation verifies that incoming requests of the form
// a, b, c, a, b, c collapse into three swaps (one per model) and that the
// second request for each model rides the fast path — either joining the
// active swap, or being pulled out of the queue when handleSwapDone promotes
// the next model.
func TestBaseRouter_QueueCollation(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
pc := newFakeProcess("c")
// Each model evicts the other two so all swaps are mutually exclusive.
planner := &stubPlanner{evict: map[string][]string{
"a": {"b", "c"},
"b": {"a", "c"},
"c": {"a", "b"},
}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
var (
completedMu sync.Mutex
completed []string
)
record := func(id string) {
completedMu.Lock()
defer completedMu.Unlock()
completed = append(completed, id)
}
ids := []string{"a", "b", "c", "a", "b", "c"}
var wg sync.WaitGroup
for _, id := range ids {
id := id
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest(id))
if w.Code != http.StatusOK {
t.Errorf("%s: status=%d body=%q", id, w.Code, w.Body.String())
return
}
record(id)
}()
// Wait for run() to absorb this request before launching the next,
// so handlerCh receives them in launch order.
waitProcessed(t, b.testProcessed, 1)
}
// All 6 are now parked in run()'s waiters/queue. Release each swap in
// sequence, waiting deterministically for each promotion to fire.
<-a.runStarted
a.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone(a) → b swap kicked off
<-pb.runStarted
pb.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone(b) → c swap kicked off
<-pc.runStarted
pc.markReady()
wg.Wait()
if got := len(completed); got != 6 {
t.Fatalf("completed=%v want 6", completed)
}
// run() fans out responses in model-grouped order (a1,a2 → b1,b2 → c1,c2)
// but waiter goroutines may be scheduled in any order after their respond
// channel fires, so completion order isn't deterministic. Per-model counts
// (combined with the runCalls checks below) are sufficient to prove queue
// collation collapsed each pair into a single swap.
aDone, bDone, cDone := 0, 0, 0
for _, id := range completed {
switch id {
case "a":
aDone++
case "b":
bDone++
case "c":
cDone++
}
}
if aDone != 2 || bDone != 2 || cDone != 2 {
t.Errorf("per-model counts: a=%d b=%d c=%d, want 2 each (order=%v)", aDone, bDone, cDone, completed)
}
// Single swap per model — the second request for each must have ridden
// the fast path (joined active swap or joined a queued sibling), not
// triggered an extra Run.
if got := a.runCalls.Load(); got != 1 {
t.Errorf("a.runCalls=%d want 1", got)
}
if got := pb.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
if got := pc.runCalls.Load(); got != 1 {
t.Errorf("c.runCalls=%d want 1", got)
}
}
// TestBaseRouter_ConcurrentDisjointSwaps verifies that two requests with
// non-conflicting evict sets are loaded in parallel: both Run() calls happen
// before either process is marked ready.
func TestBaseRouter_ConcurrentDisjointSwaps(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
// Empty evict sets for both: they can load in parallel.
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1)
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
// Both swaps must have reached Run() before either is marked ready —
// proves they ran in parallel rather than serializing.
<-a.runStarted
<-pb.runStarted
a.markReady()
pb.markReady()
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("request A did not complete")
}
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("request B did not complete")
}
if w1.Code != http.StatusOK {
t.Errorf("A status=%d body=%q", w1.Code, w1.Body.String())
}
if w2.Code != http.StatusOK {
t.Errorf("B status=%d body=%q", w2.Code, w2.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (parallel swap, no eviction)", got)
}
if got := pb.stopCalls.Load(); got != 0 {
t.Errorf("b.stopCalls=%d want 0 (parallel swap, no eviction)", got)
}
}
// TestBaseRouter_QueueDrainPromotesMultiple verifies that completing one swap
// unblocks every queued request that no longer collides — they all start in
// parallel rather than one-per-completion.
func TestBaseRouter_QueueDrainPromotesMultiple(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
pc := newFakeProcess("c")
// A's swap evicts both B and C, so B and C must queue. Once A finishes
// B and C themselves have empty evict sets, so they can start together.
planner := &stubPlanner{evict: map[string][]string{
"a": {"b", "c"},
}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1)
<-a.runStarted
// B and C arrive while A is loading. evict_b and evict_c are empty,
// but collidesWith returns true because they appear in A's evict set.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
w3 := httptest.NewRecorder()
done3 := make(chan struct{})
go func() {
b.ServeHTTP(w3, newRequest("c"))
close(done3)
}()
waitProcessed(t, b.testProcessed, 1)
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b started early: runCalls=%d", got)
}
if got := pc.runCalls.Load(); got != 0 {
t.Errorf("c started early: runCalls=%d", got)
}
// Release A. The swapDone handler should drain the queue and start
// both B and C in parallel.
a.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone(A) → drainQueue starts B and C
<-pb.runStarted
<-pc.runStarted
pb.markReady()
pc.markReady()
for i, ch := range []chan struct{}{done1, done2, done3} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
}
// TestBaseRouter_Shutdown_FailsAllInFlight verifies that shutdown returns
// the shutdown error to every waiter on every active swap AND to every
// queued request.
func TestBaseRouter_Shutdown_FailsAllInFlight(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
pc := newFakeProcess("c")
// a and b load in parallel (empty evicts). c collides with both.
planner := &stubPlanner{evict: map[string][]string{
"c": {"a", "b"},
}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, planner)
const waitersPer = 2
var wg sync.WaitGroup
codes := make([]int, 0, 2*waitersPer+1)
var codesMu sync.Mutex
record := func(code int) {
codesMu.Lock()
codes = append(codes, code)
codesMu.Unlock()
}
launch := func(model string) {
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest(model))
record(w.Code)
}()
}
// Active swaps for a and b, each with 2 waiters.
for i := 0; i < waitersPer; i++ {
launch("a")
waitProcessed(t, b.testProcessed, 1)
}
for i := 0; i < waitersPer; i++ {
launch("b")
waitProcessed(t, b.testProcessed, 1)
}
// c collides with both → queues.
launch("c")
waitProcessed(t, b.testProcessed, 1)
<-a.runStarted
<-pb.runStarted
if err := b.Shutdown(time.Second); err != nil {
t.Fatalf("Shutdown: %v", err)
}
wg.Wait()
codesMu.Lock()
defer codesMu.Unlock()
if len(codes) != 2*waitersPer+1 {
t.Fatalf("got %d responses, want %d", len(codes), 2*waitersPer+1)
}
for i, c := range codes {
if c == http.StatusOK {
t.Errorf("response %d: status=%d, want non-200 (shutdown)", i, c)
}
}
}
// TestBaseRouter_NoSwapWhileServing verifies that an already-loaded model
// is not stopped to satisfy another model's swap while it is still handling
// a request.
//
// Sequence:
// 1. r1 (A) — A loads; ServeHTTP enters and is pinned via serveBlock.
// 2. r2 (B, planner: B evicts A) — must NOT cause A.Stop while r1 is live.
// 3. r3 (A) — arrives next; the existing code queues it because B's swap
// intent collides with A.
// 4. r1 released — A finishes r1, then r3 is served by A.
// 5. B's swap then proceeds; r2 is served by B.
//
// fakeProcess.stoppedWhileServing flips true if Stop is ever called while
// a ServeHTTP is in flight — a direct, race-free signal of the violation.
func TestBaseRouter_NoSwapWhileServing(t *testing.T) {
a := newFakeProcess("a")
// autoReady left false: we markReady manually after observing runStarted,
// so autoReady's setState(Ready) cannot race with a later Stop and leave
// A in Ready, masking the bug.
a.serveBlock = make(chan struct{})
pb := newFakeProcess("b")
// Same reasoning for B: park its swap on WaitReady until we choose.
planner := &stubPlanner{evict: map[string][]string{"b": {"a"}}}
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, planner)
// r1 — load A and enter its ServeHTTP (which blocks on serveBlock).
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, b.testProcessed, 1) // handlerReq for r1
<-a.runStarted
a.markReady()
waitProcessed(t, b.testProcessed, 1) // swapDone for A
<-a.serveStarted
// r2 — would evict A. A must not be stopped while r1 is in flight.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 1)
// r3 — another request for A, arrives behind r2 and queues because
// B's swap intent (which evicts A) is recorded as active.
w3 := httptest.NewRecorder()
done3 := make(chan struct{})
go func() {
b.ServeHTTP(w3, newRequest("a"))
close(done3)
}()
waitProcessed(t, b.testProcessed, 1)
// Release r1 (and r3 if it is fast-pathed onto the still-loaded A).
// The router must hold off B's swap until A has drained.
close(a.serveBlock)
select {
case <-done1:
case <-time.After(2 * time.Second):
t.Fatal("r1 did not complete after serveBlock release")
}
// Wait for B.Run before marking it ready: markReady before Run would
// skip the Run path entirely and leave pb.runCalls at 0. In a correct
// implementation B's swap only starts after A has drained; in the
// current implementation it has already started — either way runStarted
// fires.
<-pb.runStarted
pb.markReady()
select {
case <-done2:
case <-time.After(2 * time.Second):
t.Fatal("r2 did not complete after B marked ready")
}
select {
case <-done3:
case <-time.After(2 * time.Second):
t.Fatal("r3 did not complete")
}
if w1.Code != http.StatusOK || w2.Code != http.StatusOK || w3.Code != http.StatusOK {
t.Fatalf("statuses: w1=%d w2=%d w3=%d", w1.Code, w2.Code, w3.Code)
}
if w1.Body.String() != "ok:a" {
t.Errorf("r1 body=%q want ok:a", w1.Body.String())
}
if w3.Body.String() != "ok:a" {
t.Errorf("r3 body=%q want ok:a (r3 must be served by A)", w3.Body.String())
}
if w2.Body.String() != "ok:b" {
t.Errorf("r2 body=%q want ok:b", w2.Body.String())
}
if a.stoppedWhileServing.Load() {
t.Errorf("A.Stop was called while A was still handling a request — the router swapped out a busy process")
}
}
func TestBaseRouter_ModelNotFound(t *testing.T) {
a := newFakeProcess("a")
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("unknown"))
if w.Code != http.StatusNotFound {
t.Errorf("status=%d want %d body=%q", w.Code, http.StatusNotFound, w.Body.String())
}
}
func TestBaseRouter_Shutdown_StopsAllProcesses(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
pb := newFakeProcess("b")
pb.markReady()
go pb.Run(0)
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
if err := b.Shutdown(time.Second); err != nil {
t.Fatalf("Shutdown: %v", err)
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := pb.stopCalls.Load(); got != 1 {
t.Errorf("b.stopCalls=%d want 1", got)
}
// Subsequent ServeHTTP should report 5xx.
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
if w.Code != http.StatusInternalServerError && w.Code != http.StatusServiceUnavailable {
t.Errorf("post-shutdown status=%d want 5xx body=%q", w.Code, w.Body.String())
}
// Second Shutdown should report already in progress.
if err := b.Shutdown(0); err == nil {
t.Errorf("second Shutdown returned nil, want error")
}
}
+112
View File
@@ -0,0 +1,112 @@
package router
import (
"fmt"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type Group struct {
*baseRouter
}
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
modelToGroup := make(map[string]string)
for gid, gcfg := range conf.Groups {
for _, mid := range gcfg.Members {
if existing, dup := modelToGroup[mid]; dup {
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
}
modelToGroup[mid] = gid
}
}
planner := &groupPlanner{
config: conf,
modelToGroup: modelToGroup,
}
processes := make(map[string]process.Process, len(modelToGroup))
base := newBaseRouter("group", conf, processes, planner, proxylog)
planner.processes = processes
for mid := range modelToGroup {
modelCfg, _, ok := conf.FindConfig(mid)
if !ok {
base.shutdownFn()
base.procCancel()
return nil, fmt.Errorf("no model config for %q", mid)
}
procLog := logmon.NewWriter(upstreamlog)
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
if err != nil {
base.shutdownFn()
base.procCancel()
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
}
processes[mid] = p
}
g := &Group{baseRouter: base}
go base.run()
return g, nil
}
// groupPlanner decides evictions from static group configuration.
//
// Same-group siblings are stopped when the group has swap=true. Cross-group
// members are stopped only when the target's group is exclusive; loading a
// model from a non-exclusive group leaves running exclusive groups alone,
// matching the gotcha in the original ProcessGroup behaviour.
type groupPlanner struct {
config config.Config
modelToGroup map[string]string
processes map[string]process.Process
}
func (p *groupPlanner) EvictionFor(target string, alsoRunning []string) []string {
tg := p.modelToGroup[target]
tgCfg := p.config.Groups[tg]
seen := make(map[string]struct{})
var result []string
consider := func(mID string) {
if mID == target {
return
}
if _, dup := seen[mID]; dup {
return
}
og := p.modelToGroup[mID]
switch {
case og == tg && tgCfg.Swap:
seen[mID] = struct{}{}
result = append(result, mID)
// the previous ProcessGroup behaviour did not unload exclusive groups
// when loading a non-exclusive model. This maintains that gotcha
// for backwards compatibility. The newer swap matrix approach does not
// have this issue.
case og != tg && tgCfg.Exclusive:
if ogCfg := p.config.Groups[og]; !ogCfg.Persistent {
seen[mID] = struct{}{}
result = append(result, mID)
}
}
}
for mID, proc := range p.processes {
st := proc.State()
if st == process.StateStopped || st == process.StateShutdown {
continue
}
consider(mID)
}
for _, mID := range alsoRunning {
consider(mID)
}
return result
}
func (p *groupPlanner) OnSwapStart(target string) {}
+331
View File
@@ -0,0 +1,331 @@
package router
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// newTestGroup builds a Group directly from the supplied processes and config,
// bypassing NewGroup's call to process.New.
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
t.Helper()
modelToGroup := make(map[string]string)
for gid, gcfg := range conf.Groups {
for _, mid := range gcfg.Members {
modelToGroup[mid] = gid
}
}
planner := &groupPlanner{
config: conf,
modelToGroup: modelToGroup,
processes: processes,
}
base := newBaseRouter("group", conf, processes, planner, logmon.NewWriter(io.Discard))
base.testProcessed = make(chan struct{}, 64)
g := &Group{baseRouter: base}
go base.run()
t.Cleanup(func() {
if !g.shuttingDown.Load() {
_ = g.Shutdown(time.Second)
}
})
return g
}
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
conf := config.Config{
Groups: map[string]config.GroupConfig{
"g1": {Swap: true, Members: []string{"a"}},
"g2": {Swap: true, Members: []string{"a"}},
},
Models: map[string]config.ModelConfig{
"a": {},
},
}
log := logmon.NewWriter(io.Discard)
if _, err := NewGroup(conf, log, log); err == nil {
t.Fatalf("expected error for duplicate membership")
}
}
func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0) // park a Run goroutine so Stop has something to release
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
if got := b.serveCalls.Load(); got != 1 {
t.Errorf("b.serveCalls=%d want 1", got)
}
}
func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (swap=false should not stop siblings)", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
func TestGroup_CrossGroupExclusive(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (cross-group exclusive must stop)", got)
}
}
// TestGroup_CrossGroupNonExclusiveParallel verifies that two requests for
// models in distinct non-exclusive groups load in parallel rather than
// serializing through the router's run loop.
func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
g.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, g.testProcessed, 1)
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
g.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, g.testProcessed, 1)
// Both groups load concurrently — both must reach Run() before either is
// marked ready. If the router still serialised, only one would proceed.
<-a.runStarted
<-pb.runStarted
a.markReady()
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
}
if got := pb.stopCalls.Load(); got != 0 {
t.Errorf("b.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
}
}
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
// (Swap=true) serialise even when both arrive while neither has reached
// StateStarting yet — the alsoRunning hint to the planner closes that race.
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
g.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, g.testProcessed, 1)
// Request B arrives before A transitions to StateStarting in the process
// state machine. Without the alsoRunning hint, the planner would not see
// A as running, and B would start in parallel, violating Swap=true.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
g.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, g.testProcessed, 1)
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b started in parallel: runCalls=%d want 0", got)
}
<-a.runStarted
a.markReady()
waitProcessed(t, g.testProcessed, 1) // swapDone(a) → b promoted
<-pb.runStarted
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
}
}
// TestGroup_PersistentNotEvicted verifies that a group with persistent=true
// is never evicted when another exclusive group starts loading. The running
// model in the persistent group stays alive alongside the new one.
func TestGroup_PersistentNotEvicted(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (persistent group must not be evicted)", got)
}
if a.State() != process.StateStarting && a.State() != process.StateReady {
t.Errorf("a state=%s want still running", a.State())
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestGroup_NonExclusiveDoesNotUnloadExclusive pins a backwards-compatible
// gotcha from the original ProcessGroup: when a model in a non-exclusive group
// is loaded, any running exclusive group keeps running. The two coexist.
func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Groups: map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
},
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (non-exclusive target must not unload exclusive group)", got)
}
if a.State() != process.StateStarting && a.State() != process.StateReady {
t.Errorf("a state=%s want still running", a.State())
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
+205
View File
@@ -0,0 +1,205 @@
package router
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// fakeProcess is an in-memory implementation of process.Process used to drive
// the routers through their state machine without spawning real upstreams.
type fakeProcess struct {
id string
mu sync.Mutex
state process.ProcessState
readyCh chan struct{}
stopCh chan struct{}
runStarted chan struct{} // closed on the first Run call
stopStarted chan struct{} // closed on the first Stop call
autoReady bool
// serveBlock, when non-nil, makes ServeHTTP receive from it before
// writing its response. Tests use this to hold a request in-flight.
// Closing the channel releases every blocked ServeHTTP caller.
serveBlock chan struct{}
// serveStarted is closed on the first ServeHTTP entry, letting tests
// wait deterministically for the handler to begin executing.
serveStarted chan struct{}
// stopBlock, when non-nil, makes Stop receive from it (after signalling
// stopStarted) before completing. Tests use this to prove that several
// Stop calls can be in flight simultaneously.
stopBlock chan struct{}
runCalls atomic.Int32
stopCalls atomic.Int32
serveCalls atomic.Int32
// inFlightServe counts ServeHTTP calls currently inside the handler.
// stoppedWhileServing flips true if Stop is ever called while that
// counter is non-zero — a direct, race-free observation of the
// "swap mid-request" anti-property.
inFlightServe atomic.Int32
stoppedWhileServing atomic.Bool
}
func newFakeProcess(id string) *fakeProcess {
return &fakeProcess{
id: id,
state: process.StateStopped,
readyCh: make(chan struct{}),
stopCh: make(chan struct{}),
runStarted: make(chan struct{}),
stopStarted: make(chan struct{}),
serveStarted: make(chan struct{}),
}
}
func (f *fakeProcess) setState(s process.ProcessState) {
f.mu.Lock()
defer f.mu.Unlock()
f.state = s
if s == process.StateReady {
select {
case <-f.readyCh:
default:
close(f.readyCh)
}
}
}
func (f *fakeProcess) State() process.ProcessState {
f.mu.Lock()
defer f.mu.Unlock()
return f.state
}
func (f *fakeProcess) markReady() { f.setState(process.StateReady) }
func (f *fakeProcess) Run(_ time.Duration) error {
f.runCalls.Add(1)
f.mu.Lock()
if f.state != process.StateStopped {
s := f.state
f.mu.Unlock()
return fmt.Errorf("fakeProcess %s: Run called while %s", f.id, s)
}
f.state = process.StateStarting
sc := f.stopCh
select {
case <-f.runStarted:
default:
close(f.runStarted)
}
f.mu.Unlock()
if f.autoReady {
f.setState(process.StateReady)
}
<-sc
return nil
}
func (f *fakeProcess) Stop(_ time.Duration) error {
f.stopCalls.Add(1)
if f.inFlightServe.Load() > 0 {
f.stoppedWhileServing.Store(true)
}
f.mu.Lock()
select {
case <-f.stopStarted:
default:
close(f.stopStarted)
}
f.mu.Unlock()
// Test hook: hold Stop here so the test can prove multiple Stops are
// in flight at the same time before any of them complete.
if f.stopBlock != nil {
<-f.stopBlock
}
f.mu.Lock()
defer f.mu.Unlock()
if f.state == process.StateStopped {
return nil
}
f.state = process.StateStopped
select {
case <-f.stopCh:
default:
close(f.stopCh)
}
return nil
}
func (f *fakeProcess) WaitReady(ctx context.Context) error {
f.mu.Lock()
if f.state == process.StateReady {
f.mu.Unlock()
return nil
}
rc := f.readyCh
f.mu.Unlock()
select {
case <-rc:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (f *fakeProcess) Logger() *logmon.Monitor { return logmon.NewWriter(io.Discard) }
func (f *fakeProcess) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
f.serveCalls.Add(1)
f.inFlightServe.Add(1)
defer f.inFlightServe.Add(-1)
f.mu.Lock()
select {
case <-f.serveStarted:
default:
close(f.serveStarted)
}
f.mu.Unlock()
if f.serveBlock != nil {
<-f.serveBlock
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "ok:%s", f.id)
}
// waitProcessed drains n events from ch, fataling on timeout. One event fires
// per handlerReq or swapDone fully absorbed by run().
func waitProcessed(t *testing.T, ch chan struct{}, n int) {
t.Helper()
for i := 0; i < n; i++ {
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatalf("waitProcessed: only %d/%d events received", i, n)
}
}
}
func newRequest(model string) *http.Request {
body := fmt.Sprintf(`{"model":%q}`, model)
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/json")
return r
}
func newRequestCtx(ctx context.Context, model string) *http.Request {
return newRequest(model).WithContext(ctx)
}
+277
View File
@@ -0,0 +1,277 @@
package router
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strings"
"sync"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
var loadingPaths = []string{
"/v1/chat/completions",
}
func isLoadingPath(path string) bool {
for _, p := range loadingPaths {
if strings.HasPrefix(path, p) {
return true
}
}
return false
}
type loadingWriter struct {
hasWritten bool
writer http.ResponseWriter
req *http.Request
ctx context.Context
logger *logmon.Monitor
modelName string
startTime time.Time
pendingMu sync.Mutex
pendingUpdate string
// writeMu serializes writes to the underlying writer and guards released.
// Once released is set, the streaming goroutine must not touch the writer
// again — ServeHTTP has reclaimed it (to run the real handler or to return)
// and writing/flushing a finalized response panics.
writeMu sync.Mutex
released bool
// closed by start when the goroutine finishes (after cleanup messages)
done chan struct{}
// test-only: closed when start enters its loop
loopStarted chan struct{}
// test-only: override the 1s tick interval
tickDuration time.Duration
// test-only: override character streaming speed (0 = no delay)
charPerSecond float64
}
func newLoadingWriter(logger *logmon.Monitor, modelName string, w http.ResponseWriter, req *http.Request) *loadingWriter {
s := &loadingWriter{
writer: w,
req: req,
ctx: req.Context(),
logger: logger,
modelName: modelName,
startTime: time.Now(),
tickDuration: 750 * time.Millisecond,
charPerSecond: 75,
}
s.Header().Set("Content-Type", "text/event-stream")
s.Header().Set("Cache-Control", "no-cache")
s.Header().Set("Connection", "keep-alive")
s.WriteHeader(http.StatusOK)
s.sendLine("━━━━━")
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", modelName))
return s
}
func (s *loadingWriter) setUpdate(msg string) {
s.pendingMu.Lock()
s.pendingUpdate = msg
s.pendingMu.Unlock()
}
func (s *loadingWriter) start(ctx context.Context) {
s.done = make(chan struct{})
defer close(s.done)
defer func() {
// Skip cleanup writes if the client disconnected — the connection
// is being torn down and flushing against it will panic.
if s.ctx.Err() != nil {
return
}
duration := time.Since(s.startTime)
s.sendData("\n")
s.sendLine(fmt.Sprintf("Done! (%.2fs)", duration.Seconds()))
s.sendLine("━━━━━")
s.sendLine(" ")
}()
remarks := make([]string, len(loadingRemarks))
copy(remarks, loadingRemarks)
rand.Shuffle(len(remarks), func(i, j int) {
remarks[i], remarks[j] = remarks[j], remarks[i]
})
ri := 0
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
lastRemarkTime := time.Time{}
ticker := time.NewTicker(s.tickDuration)
defer ticker.Stop()
if s.loopStarted != nil {
close(s.loopStarted)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.pendingMu.Lock()
update := s.pendingUpdate
s.pendingUpdate = ""
s.pendingMu.Unlock()
if update != "" {
s.sendData("\n")
s.sendInline(update)
s.sendData(" ")
lastRemarkTime = time.Now()
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else if time.Since(lastRemarkTime) >= nextRemarkIn {
remark := remarks[ri%len(remarks)]
ri++
s.sendData("\n")
s.sendInline(remark)
s.sendData(" ")
lastRemarkTime = time.Now()
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else {
s.sendData(".")
}
}
}
}
func (s *loadingWriter) waitForCompletion(timeout time.Duration) bool {
if s.done == nil {
return true
}
select {
case <-s.done:
return true
case <-time.After(timeout):
return false
}
}
func (s *loadingWriter) sendInline(text string) {
chunkSize := 10
if s.charPerSecond > 0 {
chunkSize = max(3, int(s.charPerSecond)/15)
}
runes := []rune(text)
for i := 0; i < len(runes); {
select {
case <-s.ctx.Done():
return
default:
}
end := i + chunkSize
if end > len(runes) {
end = len(runes)
}
chunk := string(runes[i:end])
s.sendData(chunk)
i = end
if i < len(runes) && s.charPerSecond > 0 {
time.Sleep(time.Duration(float64(time.Second) * float64(len(chunk)) / s.charPerSecond))
}
}
}
func (s *loadingWriter) sendLine(line string) {
if line == "" {
s.sendData("\n")
return
}
s.sendInline(line)
s.sendData("\n")
}
func (s *loadingWriter) sendData(data string) {
type Delta struct {
ReasoningContent string `json:"reasoning_content"`
}
type Choice struct {
Delta Delta `json:"delta"`
}
type SSEMessage struct {
Choices []Choice `json:"choices"`
}
msg := SSEMessage{
Choices: []Choice{
{
Delta: Delta{
ReasoningContent: data,
},
},
},
}
jsonData, err := json.Marshal(msg)
if err != nil {
s.logger.Errorf("<%s> Failed to marshal SSE message: %v", s.modelName, err)
return
}
s.writeMu.Lock()
defer s.writeMu.Unlock()
// Once ServeHTTP has reclaimed the writer (release), writing/flushing it
// races the real handler or panics on a finalized response. Stop here.
if s.released {
return
}
if _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData); err != nil {
s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err)
return
}
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
// release fences the loadingWriter off from the underlying ResponseWriter.
// After it returns, the streaming goroutine will not write to or flush the
// writer again: any in-flight write completes under writeMu first, and later
// writes short-circuit on released. The caller can then safely hand the writer
// to the real handler or let ServeHTTP return without racing a finalized
// response (a use-after-return Flush panics on the recycled *bufio.Writer).
func (s *loadingWriter) release() {
s.writeMu.Lock()
s.released = true
s.writeMu.Unlock()
}
func (s *loadingWriter) Header() http.Header {
return s.writer.Header()
}
func (s *loadingWriter) Write(data []byte) (int, error) {
return s.writer.Write(data)
}
func (s *loadingWriter) WriteHeader(statusCode int) {
if s.hasWritten {
return
}
s.hasWritten = true
s.writer.WriteHeader(statusCode)
s.Flush()
}
func (s *loadingWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
+133
View File
@@ -0,0 +1,133 @@
package router
var loadingRemarks = []string{
"Still faster than your last standup meeting",
"Reticulating splines",
"Waking up the hamsters",
"Teaching the model manners",
"Convincing the GPU to participate",
"Loading weights (they're heavy)",
"Please enjoy this elevator music in your head",
"Pretending to be productive",
"Reading the entire internet, page by page",
"Staring at the abyss, the abyss is buffering",
"Applying layer after layer of disembodied cognition",
"Remembering everything it forgot during quantization",
"Counting to 405 billion, one parameter at a time",
"Summoning the stochastic parroting",
"Hold on, the GPU is questioning its existence",
"Deciding which facts to hallucinate today",
"Untangling the transformer spaghetti",
"Warming up the token soup",
"Your prompt is in a queue, behind 7 billion other thoughts",
"Running `sudo apt-get install intelligence`",
"Defragmenting the latent space",
"Polishing each matrix multiplication by hand",
"Whispering sweet nothings to the attention heads",
"Aligning with human values, one reluctant epoch at a time",
"The model is thinking about what it's about to think about",
"Loading... and by loading we mean making you wait",
"Spinning up the cloud GPU, please be patient while we burn your credits",
"Applying duct tape to the context window",
"Bribing the GPU scheduler for a timeslice",
"Would you like to hear a fun fact while we load? Too bad.",
"Hot swapping your sanity for an LLM",
"Compressing optimism into FP16",
"Ignoring 90% of the attention to save you 50% of the time",
"Counting the exact same thing three times just to be sure",
"Sorry, the inference you have reached is not in service",
"Rotating the positional encodings counterclockwise for good luck",
"Your call is very important to us. Please continue to hold.",
"Unpacking the blobs. All 300GB of them.",
"Initializing the thing that initializes the other thing",
"Converting electricity into existential dread",
"Flattening the curve... wait, the tensor. Flattening the tensor.",
"Fetching the fetch of a fetch, callback hell edition",
"The GPU is at 100%. The fan is now a helicopter.",
"Baking the weights at 350° for a golden-brown inference",
"Recalibrating the confidence of things it's still wrong about",
"Have you tried turning it off and on again? No? Good, wait here.",
"Simulating deep thought by pausing dramatically",
"Loading the model that knows more than you but still can't count r's in 'strawberry'",
"Convincing CUDA to cooperate. This may take a while.",
"VRAM: 23.9GB used of 24GB. Living on the edge.",
"Processing your request with the urgency of a DMV employee",
"This model was trained on the entire internet, including that embarrassing blog you wrote in 2008",
"Dispatching tokens through a series of increasingly confused matrix multiplies",
"Gently lowering your expectations",
"Applying softmax to our feelings about this load time",
"Autoregressively generating disappointment, one token at a time",
"The magic is happening. Somewhere. Probably.",
"Synchronizing the parallel processes that run in parallel but really don't",
"Calculating the meaning of life. Spoiler: it's 42, but we're double-checking.",
"Loading... just like it said 30 seconds ago. And will say 30 seconds from now.",
"Pre-warming the cache so the first query is only slightly slower than the rest",
"Have you considered that maybe your question wasn't worth all this compute?",
"Downloading more RAM (no, really, we're mmap-ing the weights)",
"Translating your prompt into math it barely understands",
"Estimating your time remaining with 0% accuracy",
"Buffering enthusiasm",
"Model is loading. Go make some coffee. Or a three-course meal.",
"Tokenizing the dictionary, filing a grievance on behalf of 'antidisestablishmentarianism'",
"Polling for readiness in a loop that would make your CS professor weep",
"Performing percussive maintenance on the attention mechanism",
"This loading screen is singlehandedly reversing climate progress",
"Decompressing the hopes and dreams of thousands of underpaid labelers",
"Filling the key-value cache with the ghost of prompts past",
"Currently at step 3 of 9,742 of loading. We'll get there. Eventually.",
"If you stare at the spinner, it spins slower. It's science.",
"Multiplying matricies with the enthusiasm of a teenager doing chores",
"Applying `torch.nap()` until the model feels refreshed",
"Reacquainting the model with the concept of 'facts' it forgot during fine-tuning",
"Sorry for the wait. No, wait, we're not actually sorry.",
"Your GPU is now a space heater with a side hustle in linear algebra",
"Allocating memory like a billionaire allocates tax avoidance strategies",
"The model saw \"As an AI language model\" and won't stop saying it now",
"Installing dependencies you didn't know existed and will never use again",
"Re-reading 'Attention Is All You Need' for the 400th time",
"Convincing the embedding layer that context is overrated",
"Manually untangling the residual connections with a tiny comb",
"On hold with the cloud provider trying to explain why 8 H100s isn't enough",
"Adjusting temperatures: model is 0.7, server room is 104°F",
"Please hold while we justify this electricity bill to accounting",
"Stacking decoder blocks like a Jenga tower at a LAN party",
"Compensating for your lack of patience with our lack of speed",
"This is a loading screen comment. Loading screens have comments now. Welcome to the future.",
"Processing the entire works of Shakespeare backwards just in case",
"The model is loading slower than your last `npm install`",
"Rehearsing plausible-sounding explanations for why it got everything wrong",
"Populating the context with filler while you wait for actual content",
"Optimizing for BLEU score, which definitely correlates with making you laugh",
"Generating an embedding for each and every letter of the alphabet, individually",
"Coming soon: llama-swap v2 with actual performance improvements. Probably.",
"Loading a model larger than your attention span",
"Performing a seance to invoke the spirit of Geoff Hinton",
"Did you know loading screens were invented to prevent users from smashing their monitors? Now you do.",
"Converting all the internet's bad opinions into a surprisingly useful autocomplete",
"Laying down each layer with the care of a Michelin-starred pastry chef",
"Checking if the model still thinks birds are government drones. Yep.",
"Activating the neurons responsible for 'I cannot assist with that request'",
"This model was trained on the same internet that brought you Rickrolling. You're welcome.",
"Realigning the alignment so it aligns with the previous alignment",
"Running `nvidia-smi` and sighing heavily",
"If you close your eyes, the loading bar moves faster. Proven by science.",
"EULA said 'by using this software you agree to wait forever' and you clicked Accept",
"Zipping the GPUs to make them go faster",
"Padding the context window with existential padding",
"We could have used a smaller model but someone wanted 'quality'",
"Disentangling the latent space into something resembling coherence",
"Slow is smooth, smooth is fast, but this is just slow",
"Memory-mapping like it's a AAA title from 2012",
"Your patience has been tokenized and added to the training set. Thank you for your contribution.",
"Loading is CPU-bound and your CPU is busy regretting its life choices",
"Exploring the high-dimensional manifold of ways to say 'just a moment'",
"The model is experiencing a brief but intense moment of imposter syndrome",
"Initializing 7B parameters by rolling 7B 16-sided dice",
"Panic! at the disk I/O",
"Intelligence is loading... your definition of intelligence may vary",
"This model was distilled. Unlike your patience, which is evaporating.",
"Unzipping the model. It's a .gguf file, not a metaphor.",
"Running inference on the concept of 'soon' to estimate remaining time",
"Loading with all the speed of a government-funded IT project",
"A blank terminal is a terrible thing to waste. Here's a loading message instead.",
}
+328
View File
@@ -0,0 +1,328 @@
package router
import (
"bufio"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
func TestLoadingWriter_SSEHeadersAndInitialMessage(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
if ct := lw.Header().Get("Content-Type"); ct != "text/event-stream" {
t.Errorf("Content-Type: want text/event-stream, got %q", ct)
}
if cc := lw.Header().Get("Cache-Control"); cc != "no-cache" {
t.Errorf("Cache-Control: want no-cache, got %q", cc)
}
if conn := lw.Header().Get("Connection"); conn != "keep-alive" {
t.Errorf("Connection: want keep-alive, got %q", conn)
}
body := w.Body.String()
if !strings.HasPrefix(body, "data: ") {
t.Errorf("expected SSE data: prefix, got: %s", body)
}
content := extractStreamedContent(body)
if !strings.Contains(content, "━━━━━\n") {
t.Errorf("missing separator in streamed content: %q", content)
}
if !strings.Contains(content, "llama-swap loading model: test-model\n") {
t.Errorf("missing initial message in streamed content: %q", content)
}
}
func TestLoadingWriter_WriteHeaderOnce(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.WriteHeader(http.StatusCreated)
if w.Code != http.StatusOK {
t.Errorf("first WriteHeader: want %d, got %d", http.StatusOK, w.Code)
}
}
func TestLoadingWriter_WritePassthrough(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.Write([]byte("hello"))
lw.Flush()
body := w.Body.String()
if !strings.Contains(body, "hello") {
t.Errorf("Write passthrough failed, body: %s", body)
}
}
func TestLoadingWriter_StartStopsOnCancel(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go lw.start(ctx)
<-lw.loopStarted
cancel()
if !lw.waitForCompletion(time.Second) {
t.Fatal("waitForCompletion timed out")
}
body := w.Body.String()
if !strings.Contains(body, "Done!") {
t.Errorf("expected Done! message, body: %s", body)
}
}
func TestLoadingWriter_StartShowsSetUpdate(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.charPerSecond = 0
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go lw.start(ctx)
<-lw.loopStarted
lw.setUpdate("custom status message")
time.Sleep(50 * time.Millisecond)
cancel()
if !lw.waitForCompletion(time.Second) {
t.Fatal("waitForCompletion timed out")
}
body := w.Body.String()
content := extractStreamedContent(body)
if !strings.Contains(content, "custom status message") {
t.Errorf("expected setUpdate message in output, got: %q", content)
}
}
func TestLoadingWriter_SendDataFormat(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.sendData("hello world")
body := w.Body.String()
if !strings.Contains(body, `"reasoning_content":"hello world"`) {
t.Errorf("expected reasoning_content in SSE data, body: %s", body)
}
if !strings.HasPrefix(body, "data: ") {
t.Errorf("expected data: prefix, got: %s", body)
}
}
func TestLoadingWriter_SendLine(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.charPerSecond = 0
// Capture only the content from this sendLine call
before := w.Body.Len()
lw.sendLine("line content")
after := w.Body.Len()
chunkBody := w.Body.String()[before:after]
content := extractStreamedContent(chunkBody)
if content != "line content\n" {
t.Errorf("expected complete streamed line, got: %q", content)
}
}
func TestLoadingWriter_FlushesPeriodicallyDuringStatusUpdates(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.charPerSecond = 0
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
lw.start(ctx)
close(done)
}()
<-lw.loopStarted
time.Sleep(50 * time.Millisecond)
cancel()
<-done
body := w.Body.String()
lines := countSSEMessages(body)
if lines < 2 {
t.Errorf("expected multiple SSE messages from periodic updates, got %d", lines)
}
}
func TestLoadingWriter_ReqStored(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
if lw.req != req {
t.Fatal("req not stored")
}
}
func TestIsLoadingPath(t *testing.T) {
tests := []struct {
path string
want bool
}{
{"/v1/chat/completions", true},
{"/v1/chat/completions/extra", true},
{"/v1/completions", false},
{"/v1/embeddings", false},
{"/health", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := isLoadingPath(tt.path); got != tt.want {
t.Errorf("isLoadingPath(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
func TestExtractContext_Streaming_GET(t *testing.T) {
tests := []struct {
name string
query string
wantStreaming bool
}{
{"streaming true", "model=llama3&stream=true", true},
{"streaming false", "model=llama3&stream=false", false},
{"no stream param", "model=llama3", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
got, err := ExtractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if got.Streaming != tt.wantStreaming {
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
}
})
}
}
func TestExtractContext_Streaming_JSON(t *testing.T) {
tests := []struct {
name string
body string
wantStreaming bool
}{
{"streaming true", `{"model":"llama3","stream":true}`, true},
{"streaming false", `{"model":"llama3","stream":false}`, false},
{"no stream param", `{"model":"llama3"}`, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
got, err := ExtractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if got.Streaming != tt.wantStreaming {
t.Errorf("Streaming: want %v, got %v", tt.wantStreaming, got.Streaming)
}
})
}
}
func TestExtractContext_Streaming_URLEncodedForm(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader("model=whisper-1&stream=true"))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
got, err := ExtractContext(r)
if err != nil {
t.Fatalf("ExtractContext: %v", err)
}
if !got.Streaming {
t.Error("Streaming should be true")
}
}
func countSSEMessages(s string) int {
scanner := bufio.NewScanner(strings.NewReader(s))
count := 0
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
count++
}
}
return count
}
func extractStreamedContent(body string) string {
var result strings.Builder
scanner := bufio.NewScanner(strings.NewReader(body))
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
jsonData := strings.TrimPrefix(line, "data: ")
var msg struct {
Choices []struct {
Delta struct {
ReasoningContent string `json:"reasoning_content"`
} `json:"delta"`
} `json:"choices"`
}
if err := json.Unmarshal([]byte(jsonData), &msg); err != nil {
continue
}
if len(msg.Choices) > 0 {
result.WriteString(msg.Choices[0].Delta.ReasoningContent)
}
}
return result.String()
}
+101
View File
@@ -0,0 +1,101 @@
package router
import (
"fmt"
"sort"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type Matrix struct {
*baseRouter
}
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
if conf.Matrix == nil {
return nil, fmt.Errorf("matrix router requires a matrix configuration")
}
planner := &matrixPlanner{
solver: newMatrixSolver(conf.ExpandedSets, conf.Matrix.ResolvedEvictCosts()),
logger: proxylog,
}
// Build a process for every model in the config. Any model can run alone
// even if it is not part of a set; this mirrors proxy.NewMatrix.
processes := make(map[string]process.Process, len(conf.Models))
base := newBaseRouter("matrix", conf, processes, planner, proxylog)
planner.processes = processes
for mid, modelCfg := range conf.Models {
procLog := logmon.NewWriter(upstreamlog)
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
if err != nil {
base.shutdownFn()
base.procCancel()
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
}
processes[mid] = p
}
r := &Matrix{baseRouter: base}
go base.run()
return r, nil
}
// matrixPlanner decides evictions by asking the matrix solver against the
// current running set.
type matrixPlanner struct {
solver *matrixSolver
processes map[string]process.Process
logger *logmon.Monitor
}
func (p *matrixPlanner) EvictionFor(target string, alsoRunning []string) []string {
return p.solver.Solve(target, p.runningSet(alsoRunning)).Evict
}
func (p *matrixPlanner) OnSwapStart(target string) {
running := p.runningModels()
result := p.solver.Solve(target, running)
switch {
case len(result.Evict) > 0:
p.logger.Infof("matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
target, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
case len(running) == 0:
p.logger.Infof("matrix: model=%s starting (no models running)", target)
default:
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
}
}
func (p *matrixPlanner) runningModels() []string {
return p.runningSet(nil)
}
// runningSet returns the union of live processes (State != Stopped/Shutdown)
// and any extra IDs the baseRouter has already committed to loading but which
// the process state machine has not yet reflected.
func (p *matrixPlanner) runningSet(alsoRunning []string) []string {
seen := make(map[string]struct{}, len(p.processes))
var running []string
for id, proc := range p.processes {
st := proc.State()
if st == process.StateStopped || st == process.StateShutdown {
continue
}
seen[id] = struct{}{}
running = append(running, id)
}
for _, id := range alsoRunning {
if _, dup := seen[id]; dup {
continue
}
seen[id] = struct{}{}
running = append(running, id)
}
sort.Strings(running)
return running
}
+132
View File
@@ -0,0 +1,132 @@
package router
import (
"slices"
"github.com/mostlygeek/llama-swap/internal/config"
)
// matrixSolver contains pure swap-decision logic with no Process dependencies.
// It is safe for concurrent reads after construction.
type matrixSolver struct {
expandedSets []config.ExpandedSet // all valid model combinations
evictCosts map[string]int // real model name -> eviction cost (default 1)
modelToSets map[string][]int // model name -> indices into expandedSets
}
func newMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *matrixSolver {
modelToSets := make(map[string][]int)
for i, es := range expandedSets {
for _, model := range es.Models {
modelToSets[model] = append(modelToSets[model], i)
}
}
return &matrixSolver{
expandedSets: expandedSets,
evictCosts: evictCosts,
modelToSets: modelToSets,
}
}
// solveResult describes what the solver decided.
type solveResult struct {
Evict []string // running models that must be stopped
TargetSet []string // the chosen set of models (for informational purposes)
SetName string // name of the chosen set
DSL string // original DSL expression for the chosen set
TotalCost int // total eviction cost
}
// Solve determines which models to evict when a model is requested.
//
// Algorithm:
// 1. If requestedModel is already running, no eviction needed.
// 2. Find all sets containing requestedModel.
// 3. If no sets found, the model runs alone; evict all running models.
// 4. For each candidate set, compute cost = sum of evict_costs for running
// models NOT in that set.
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
// 6. Return models to evict and the chosen set.
func (s *matrixSolver) Solve(requestedModel string, runningModels []string) solveResult {
if slices.Contains(runningModels, requestedModel) {
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
return solveResult{
TargetSet: runningModels,
SetName: setName,
DSL: dsl,
}
}
candidateIndices := s.modelToSets[requestedModel]
// Model not in any set: runs alone, evict everything.
if len(candidateIndices) == 0 {
evict := make([]string, len(runningModels))
copy(evict, runningModels)
return solveResult{
Evict: evict,
TargetSet: []string{requestedModel},
}
}
bestCost := -1
bestIdx := -1
for _, idx := range candidateIndices {
setModels := s.expandedSets[idx].Models
cost := 0
for _, running := range runningModels {
if !slices.Contains(setModels, running) {
cost += s.evictCost(running)
}
}
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
bestCost = cost
bestIdx = idx
}
}
chosen := s.expandedSets[bestIdx]
var evict []string
for _, running := range runningModels {
if !slices.Contains(chosen.Models, running) {
evict = append(evict, running)
}
}
return solveResult{
Evict: evict,
TargetSet: chosen.Models,
SetName: chosen.SetName,
DSL: chosen.DSL,
TotalCost: bestCost,
}
}
// findMatchingSet finds the expanded set that contains all running models.
// Returns the set name and DSL, or empty strings if no match.
func (s *matrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
for _, idx := range s.modelToSets[requestedModel] {
set := s.expandedSets[idx]
allInSet := true
for _, m := range runningModels {
if !slices.Contains(set.Models, m) {
allInSet = false
break
}
}
if allInSet {
return set.SetName, set.DSL
}
}
return "", ""
}
func (s *matrixSolver) evictCost(model string) int {
if cost, ok := s.evictCosts[model]; ok {
return cost
}
return 1
}
+244
View File
@@ -0,0 +1,244 @@
package router
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// newTestMatrix builds a Matrix router from supplied processes, bypassing
// NewMatrix's call to process.New.
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
t.Helper()
logger := logmon.NewWriter(io.Discard)
planner := &matrixPlanner{
solver: newMatrixSolver(expanded, evictCosts),
processes: processes,
logger: logger,
}
base := newBaseRouter("matrix", conf, processes, planner, logger)
base.testProcessed = make(chan struct{}, 64)
r := &Matrix{baseRouter: base}
go base.run()
t.Cleanup(func() {
if !r.shuttingDown.Load() {
_ = r.Shutdown(time.Second)
}
})
return r
}
func baseMatrixConfig() config.Config {
return config.Config{
HealthCheckTimeout: 5,
Matrix: &config.MatrixConfig{},
}
}
// TestMatrix_SwapEvictsConflicting verifies that loading a model triggers
// eviction of running models that are not in any shared set with it.
func TestMatrix_SwapEvictsConflicting(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0) // park a Run goroutine so Stop has something to release
b := newFakeProcess("b")
b.autoReady = true
// Two single-model sets: a and b never coexist, so loading b must evict a.
expanded := []config.ExpandedSet{
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
r.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestMatrix_CoexistInSet verifies that a model is not evicted when the target
// shares a set with it (the fast path applies if the target is already ready).
func TestMatrix_CoexistInSet(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
// Both fit in s_ab, so b's swap should not stop a.
expanded := []config.ExpandedSet{
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
r.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestMatrix_CoexistingSetParallel verifies that two models that share an
// expanded set load in parallel — the solver returns empty Evict for both,
// the collision predicate clears them, and both swaps run together.
func TestMatrix_CoexistingSetParallel(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
expanded := []config.ExpandedSet{
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
r.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, r.testProcessed, 1)
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
r.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, r.testProcessed, 1)
<-a.runStarted
<-pb.runStarted
a.markReady()
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
}
if got := pb.stopCalls.Load(); got != 0 {
t.Errorf("b.stopCalls=%d want 0 (coexists with a)", got)
}
}
// TestMatrix_IncompatibleQueues verifies that the second request for a model
// that cannot coexist with the in-flight first model queues until the first
// completes, and then evicts it. This exercises the alsoRunning hint via the
// matrix solver's union into runningSet.
func TestMatrix_IncompatibleQueues(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
expanded := []config.ExpandedSet{
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
r.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, r.testProcessed, 1)
// B arrives before A transitions to StateStarting. The solver sees A via
// alsoRunning and returns evict=[a], so collidesWith forces B to queue.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
r.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, r.testProcessed, 1)
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b started in parallel: runCalls=%d want 0", got)
}
<-a.runStarted
a.markReady()
waitProcessed(t, r.testProcessed, 1) // swapDone(a) → b promoted, evicts a
<-pb.runStarted
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
}
}
// TestMatrixSolver_TieBreakDefinitionOrder pins the solver's tie-break rule:
// when multiple candidate sets have equal eviction cost, the earlier-defined
// set wins.
func TestMatrixSolver_TieBreakDefinitionOrder(t *testing.T) {
expanded := []config.ExpandedSet{
{SetName: "first", DSL: "a & b", Models: []string{"a", "b"}},
{SetName: "second", DSL: "a & c", Models: []string{"a", "c"}},
}
s := newMatrixSolver(expanded, nil)
// No models running, request "a": both sets have cost 0 and contain a.
// Definition order: "first" wins.
result := s.Solve("a", nil)
if result.SetName != "first" {
t.Errorf("SetName=%q want %q", result.SetName, "first")
}
}
// TestMatrixSolver_EvictCostsPreferred verifies that higher evict costs steer
// the solver toward a cheaper set.
func TestMatrixSolver_EvictCostsPreferred(t *testing.T) {
// b is expensive to evict; c is cheap. Request "a" with both b and c
// running. The solver should pick the set that keeps b.
expanded := []config.ExpandedSet{
{SetName: "a_with_c", DSL: "a & c", Models: []string{"a", "c"}}, // would evict b (cost 10)
{SetName: "a_with_b", DSL: "a & b", Models: []string{"a", "b"}}, // would evict c (cost 1)
}
s := newMatrixSolver(expanded, map[string]int{"b": 10, "c": 1})
result := s.Solve("a", []string{"b", "c"})
if result.SetName != "a_with_b" {
t.Errorf("SetName=%q want %q (keep expensive b)", result.SetName, "a_with_b")
}
if len(result.Evict) != 1 || result.Evict[0] != "c" {
t.Errorf("Evict=%v want [c]", result.Evict)
}
}
+188
View File
@@ -0,0 +1,188 @@
package router
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
type peerMember struct {
peerID string
reverseProxy *httputil.ReverseProxy
apiKey string
}
type Peer struct {
cfg config.Config
logger *logmon.Monitor
peers map[string]*peerMember
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
inflight sync.WaitGroup
}
func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
peers := cfg.Peers
modelMap := make(map[string]*peerMember)
peerIDs := make([]string, 0, len(peers))
for peerID := range peers {
peerIDs = append(peerIDs, peerID)
}
sort.Strings(peerIDs)
for _, peerID := range peerIDs {
peer := peers[peerID]
peerTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
}
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
reverseProxy.Transport = peerTransport
originalDirector := reverseProxy.Director
reverseProxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = req.URL.Host
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
logger.Warnf("peer %s: proxy error: %v", peerID, err)
errMsg := fmt.Sprintf("peer proxy error: %v", err)
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
}
http.Error(w, errMsg, http.StatusBadGateway)
}
pp := &peerMember{
peerID: peerID,
reverseProxy: reverseProxy,
apiKey: peer.ApiKey,
}
for _, modelID := range peer.Models {
if _, found := modelMap[modelID]; found {
logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
continue
}
modelMap[modelID] = pp
}
}
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
return &Peer{
cfg: cfg,
logger: logger,
peers: modelMap,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
}, nil
}
func (r *Peer) Handles(model string) bool {
_, ok := r.peers[model]
return ok
}
func (r *Peer) Shutdown(timeout time.Duration) error {
if !r.shuttingDown.CompareAndSwap(false, true) {
return fmt.Errorf("shutdown already in progress")
}
if timeout == 0 {
r.shutdownFn()
r.inflight.Wait()
return nil
}
done := make(chan struct{})
go func() {
r.inflight.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-time.After(timeout):
r.shutdownFn()
r.inflight.Wait()
return fmt.Errorf("peer shutdown timed out after %v", timeout)
}
}
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.shuttingDown.Load() {
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
return
}
r.inflight.Add(1)
defer r.inflight.Done()
data, err := FetchContext(req, r.cfg)
if err != nil {
SendError(w, req, err)
return
}
pp, found := r.peers[data.ModelID]
if !found {
r.logger.Warnf("peer model not found: %s", data.ModelID)
SendError(w, req, ErrNoPeerModelFound)
return
}
r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID)
if pp.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+pp.apiKey)
req.Header.Set("x-api-key", pp.apiKey)
}
// Cancel the proxy request when the client disconnects or shutdown times out.
// AfterFunc links both parent contexts to our child without a goroutine leak.
ctx, cancel := context.WithCancel(context.Background())
stopReq := context.AfterFunc(req.Context(), cancel)
stopShutdown := context.AfterFunc(r.shutdownCtx, cancel)
req = req.WithContext(ctx)
pp.reverseProxy.ServeHTTP(w, req)
stopShutdown()
stopReq()
cancel()
}
+611
View File
@@ -0,0 +1,611 @@
package router
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
var testLogger = logmon.NewWriter(os.Stdout)
func init() {
testLogger.SetLogLevel(logmon.LevelWarn)
}
func TestNewPeer_EmptyPeers(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
if pr == nil {
t.Fatal("expected non-nil Peer")
}
if len(pr.peers) != 0 {
t.Fatalf("expected empty peers map, got %d entries", len(pr.peers))
}
}
func TestNewPeer_SinglePeer(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
ApiKey: "test-key",
Models: []string{"model-a", "model-b"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 2 {
t.Fatalf("expected 2 entries, got %d", len(pr.peers))
}
if _, ok := pr.peers["model-a"]; !ok {
t.Error("expected model-a to be mapped")
}
if _, ok := pr.peers["model-b"]; !ok {
t.Error("expected model-b to be mapped")
}
if _, ok := pr.peers["model-c"]; ok {
t.Error("expected model-c to not be mapped")
}
}
func TestNewPeer_MultiplePeers(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"model-a", "model-b"},
},
"peer2": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"model-c", "model-d"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 4 {
t.Fatalf("expected 4 entries, got %d", len(pr.peers))
}
for _, m := range []string{"model-a", "model-b", "model-c", "model-d"} {
if _, ok := pr.peers[m]; !ok {
t.Errorf("expected %s to be mapped", m)
}
}
}
func TestNewPeer_DuplicateModel(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"alpha-peer": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"duplicate-model"},
},
"beta-peer": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"duplicate-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 1 {
t.Fatalf("expected 1 entry for duplicate model, got %d", len(pr.peers))
}
if _, ok := pr.peers["duplicate-model"]; !ok {
t.Error("expected duplicate-model to be mapped")
}
}
func TestPeer_ServeHTTP_Success(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("response from peer"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
if w.Body.String() != "response from peer" {
t.Errorf("expected 'response from peer', got %q", w.Body.String())
}
}
func TestPeer_ServeHTTP_ModelNotFoundInContext(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "secret-api-key",
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if receivedAuthHeader != "Bearer secret-api-key" {
t.Errorf("expected 'Bearer secret-api-key', got %q", receivedAuthHeader)
}
}
func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "",
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if receivedAuthHeader != "" {
t.Errorf("expected no auth header, got %q", receivedAuthHeader)
}
}
func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
var receivedHost string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHost = r.Host
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if !strings.HasPrefix(receivedHost, "127.0.0.1:") {
t.Errorf("expected Host to start with '127.0.0.1:', got %q", receivedHost)
}
}
func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Header().Get("X-Accel-Buffering") != "no" {
t.Errorf("expected X-Accel-Buffering=no, got %q", w.Header().Get("X-Accel-Buffering"))
}
}
func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "shutting down") {
t.Errorf("expected 'shutting down' in body, got %q", w.Body.String())
}
}
func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
started := make(chan struct{})
released := make(chan struct{})
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-released
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
}()
<-started
shutdownDone := make(chan error, 1)
go func() {
shutdownDone <- pr.Shutdown(500 * time.Millisecond)
}()
// Shutdown should be waiting on inflight. If it finished already something is wrong.
time.Sleep(100 * time.Millisecond)
select {
case err := <-shutdownDone:
t.Errorf("shutdown completed before inflight finished: %v", err)
default:
}
close(released)
wg.Wait()
select {
case err := <-shutdownDone:
if err != nil {
t.Errorf("shutdown errored after inflight completed: %v", err)
}
case <-time.After(2 * time.Second):
t.Error("shutdown did not complete after inflight finished")
}
}
func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
started := make(chan struct{})
released := make(chan struct{})
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-released
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "test-model", ModelID: "test-model"}))
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
}()
<-started
err = pr.Shutdown(100 * time.Millisecond)
if err == nil {
t.Error("expected timeout error from shutdown")
}
close(released)
wg.Wait()
}
func TestPeer_ShutdownMultiple(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err == nil {
t.Error("expected error on second shutdown")
}
if !strings.Contains(err.Error(), "already in progress") {
t.Errorf("expected 'already in progress', got %q", err.Error())
}
}
func TestPeer_ServeHTTP_ModelExtractedFromBody(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"extracted-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
body := strings.NewReader(`{"model":"extracted-model","prompt":"hello"}`)
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"context-model"},
},
"peer2": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"body-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
req.Header.Set("Content-Type", "application/json")
*req = *req.WithContext(SetContext(req.Context(), ReqContextData{Model: "context-model", ModelID: "context-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestNewPeer_CustomTimeouts(t *testing.T) {
proxyURL, _ := url.Parse("http://localhost:8080")
peers := config.PeerDictionaryConfig{
"test-peer": config.PeerConfig{
Proxy: "http://localhost:8080",
ProxyURL: proxyURL,
Models: []string{"model1"},
Timeouts: config.TimeoutsConfig{
Connect: 45,
ResponseHeader: 300,
TLSHandshake: 15,
ExpectContinue: 2,
IdleConn: 120,
},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
member, ok := pr.peers["model1"]
if !ok {
t.Fatal("expected model1 to be mapped")
}
transport, ok := member.reverseProxy.Transport.(*http.Transport)
if !ok {
t.Fatal("expected Transport to be *http.Transport")
}
if transport.ResponseHeaderTimeout != 300*time.Second {
t.Errorf("expected ResponseHeaderTimeout=%v, got %v", 300*time.Second, transport.ResponseHeaderTimeout)
}
if transport.TLSHandshakeTimeout != 15*time.Second {
t.Errorf("expected TLSHandshakeTimeout=%v, got %v", 15*time.Second, transport.TLSHandshakeTimeout)
}
if transport.ExpectContinueTimeout != 2*time.Second {
t.Errorf("expected ExpectContinueTimeout=%v, got %v", 2*time.Second, transport.ExpectContinueTimeout)
}
if transport.IdleConnTimeout != 120*time.Second {
t.Errorf("expected IdleConnTimeout=%v, got %v", 120*time.Second, transport.IdleConnTimeout)
}
if !transport.ForceAttemptHTTP2 {
t.Error("expected ForceAttemptHTTP2 to be true")
}
}
+199
View File
@@ -0,0 +1,199 @@
package router
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
"github.com/tidwall/gjson"
)
type contextkey struct {
name string
}
type ReqContextData struct {
Model string
ModelID string
Streaming bool
SendLoadingState bool
}
var (
ErrNoModelInContext = fmt.Errorf("no model in request context")
ErrNoRouterFound = fmt.Errorf("no router found for model")
ErrNoPeerModelFound = fmt.Errorf("peer model not found")
ErrNoLocalModelFound = fmt.Errorf("local model not found")
ContextKey = &contextkey{"context"}
)
type Router interface {
// Shutdown blocks until the router has shutdown returning nil
// when the router has shutdown successfully.
//
// timeout controls how long to wait for inflight requests to finish. After
// the timeout all inflight requests will be cancelled.
Shutdown(timeout time.Duration) error
// ServeHTTP implements the http.Handler and requests coming in will
// trigger any model swapping and routing logic.
ServeHTTP(http.ResponseWriter, *http.Request)
// Handles reports whether this router can serve requests for the given model.
Handles(model string) bool
}
// LocalRouter is a Router backed by local processes whose state can be
// inspected and which can be individually stopped. Peer routers, which only
// forward to remote hosts, do not implement it.
type LocalRouter interface {
Router
// RunningModels returns the current state of every process that is not
// stopped or shut down, keyed by model ID.
RunningModels() map[string]process.ProcessState
// Unload stops the named models, or every running model when none are
// named. It blocks until each targeted process has stopped.
Unload(timeout time.Duration, models ...string)
// ProcessLogger returns the log monitor for the named model's process.
// modelID must be a real (non-alias) config key. Returns false when the
// model is not known to this router.
ProcessLogger(modelID string) (*logmon.Monitor, bool)
}
// FetchContext will attempt to get the model id from the context then
// from the model body. If it extracts the model from the body it will
// store the model in the context for downstream handlers. An error
// will be returned when model can not be fetch from either location.
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
data, ok := ReadContext(r.Context())
if ok {
return data, nil
}
if data, err := ExtractContext(r); err == nil {
realName, _ := cfg.RealModelName(data.Model)
if realName == "" {
realName = data.Model
}
data.ModelID = realName
if mc, ok := cfg.Models[realName]; ok {
data.SendLoadingState = mc.SendLoadingState != nil && *mc.SendLoadingState
}
*r = *r.WithContext(SetContext(r.Context(), data))
return data, nil
}
return ReqContextData{}, ErrNoModelInContext
}
func SetContext(ctx context.Context, data ReqContextData) context.Context {
return context.WithValue(ctx, ContextKey, data)
}
func ReadContext(ctx context.Context) (ReqContextData, bool) {
data, ok := ctx.Value(ContextKey).(ReqContextData)
return data, ok
}
// ExtractContext pulls the model name from an HTTP request without consuming the
// body. For GET requests it reads the "model" query parameter. For POST
// requests it inspects Content-Type and parses JSON, multipart/form-data, or
// application/x-www-form-urlencoded bodies. The request body is always restored
// before returning so downstream handlers — including reverse proxies that
// forward raw bytes upstream — can still read it.
func ExtractContext(r *http.Request) (ReqContextData, error) {
if r.Method == http.MethodGet {
if model := r.URL.Query().Get("model"); model != "" {
return ReqContextData{Model: model, Streaming: r.URL.Query().Get("stream") == "true"}, nil
}
return ReqContextData{}, fmt.Errorf("missing 'model' query parameter")
}
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return ReqContextData{}, fmt.Errorf("error reading request body: %w", err)
}
defer func() {
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}()
contentType := r.Header.Get("Content-Type")
if strings.Contains(contentType, "application/json") {
model := gjson.GetBytes(bodyBytes, "model").String()
if model == "" {
return ReqContextData{}, fmt.Errorf("missing or empty 'model' in JSON body")
}
return ReqContextData{Model: model, Streaming: gjson.GetBytes(bodyBytes, "stream").Bool()}, nil
}
// Form parsers read from r.Body, so feed them a fresh reader over the
// buffered bytes. The deferred restore above will reset r.Body again
// after parsing.
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if strings.Contains(contentType, "multipart/form-data") {
if err := r.ParseMultipartForm(32 << 20); err != nil {
return ReqContextData{}, fmt.Errorf("error parsing multipart form: %w", err)
}
} else {
if err := r.ParseForm(); err != nil {
return ReqContextData{}, fmt.Errorf("error parsing form: %w", err)
}
}
if model := r.FormValue("model"); model != "" {
return ReqContextData{Model: model, Streaming: r.FormValue("stream") == "true"}, nil
}
return ReqContextData{}, fmt.Errorf("missing 'model' parameter")
}
func SendError(w http.ResponseWriter, r *http.Request, err error) {
switch {
case errors.Is(err, ErrNoModelInContext):
SendResponse(w, r, http.StatusNotFound, "no model id could be identified")
case errors.Is(err, ErrNoPeerModelFound):
SendResponse(w, r, http.StatusNotFound, "no peer found for requested model")
case errors.Is(err, ErrNoLocalModelFound):
SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
case errors.Is(err, ErrNoRouterFound):
SendResponse(w, r, http.StatusNotFound, "no router for requested model")
default:
SendResponse(w, r, http.StatusInternalServerError, fmt.Sprintf("unspecific error: %v", err))
}
}
// SendResponse detects what content type the client prefers and returns an error response in that format.
func SendResponse(w http.ResponseWriter, r *http.Request, status int, message string) {
// Check Accept header for preferred response format
acceptHeader := r.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/plain") {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(status)
w.Write([]byte(fmt.Sprintf("llama-swap: %s", message)))
return
}
if strings.Contains(acceptHeader, "text/html") {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(status)
w.Write([]byte(fmt.Sprintf(`<html><body><h1>llama-swap</h1><p>%s</p></body></html>`, message)))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
w.Write([]byte(fmt.Sprintf(`{"src":"llama-swap", "error": "%s"}`, message)))
}
+275
View File
@@ -0,0 +1,275 @@
package router
import (
"bytes"
"context"
"io"
"mime/multipart"
"net/http"
"net/url"
"strings"
"testing"
)
func TestExtractContext_GET(t *testing.T) {
tests := []struct {
name string
query string
wantModel string
wantErr bool
}{
{"model present", "model=llama3", "llama3", false},
{"model with slashes", "model=author/model-7b", "author/model-7b", false},
{"model missing", "", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/?"+tt.query, nil)
got, err := ExtractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_JSON(t *testing.T) {
tests := []struct {
name string
body string
wantModel string
wantErr bool
}{
{"model present", `{"model":"llama3","stream":true}`, "llama3", false},
{"model with slashes", `{"model":"author/model-7b"}`, "author/model-7b", false},
{"model empty string", `{"model":""}`, "", true},
{"model key missing", `{"stream":true}`, "", true},
{"invalid json", `not-json`, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(tt.body))
r.Header.Set("Content-Type", "application/json")
got, err := ExtractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_URLEncodedForm(t *testing.T) {
tests := []struct {
name string
formModel string
wantModel string
wantErr bool
}{
{"model present", "whisper-1", "whisper-1", false},
{"model missing", "", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
form := url.Values{}
if tt.formModel != "" {
form.Set("model", tt.formModel)
}
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(form.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
got, err := ExtractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_MultipartForm(t *testing.T) {
tests := []struct {
name string
formModel string
wantModel string
wantErr bool
}{
{"model present", "whisper-1", "whisper-1", false},
{"model missing", "", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
if tt.formModel != "" {
fw, _ := mw.CreateFormField("model")
fw.Write([]byte(tt.formModel))
}
mw.Close()
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &buf)
r.Header.Set("Content-Type", mw.FormDataContentType())
got, err := ExtractContext(r)
if (err != nil) != tt.wantErr {
t.Fatalf("wantErr=%v got err=%v", tt.wantErr, err)
}
if got.Model != tt.wantModel {
t.Errorf("want %q got %q", tt.wantModel, got.Model)
}
})
}
}
func TestExtractContext_JSONBodyRestored(t *testing.T) {
body := `{"model":"llama3","stream":true}`
r, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/json")
if _, err := ExtractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if string(remaining) != body {
t.Errorf("body not restored: want %q got %q", body, string(remaining))
}
}
func TestExtractContext_MultipartBodyRestored(t *testing.T) {
var buf bytes.Buffer
mw := multipart.NewWriter(&buf)
fw, _ := mw.CreateFormField("model")
fw.Write([]byte("whisper-1"))
ff, _ := mw.CreateFormFile("file", "audio.wav")
ff.Write([]byte("fake-audio-bytes"))
mw.Close()
original := buf.Bytes()
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(original))
r.Header.Set("Content-Type", mw.FormDataContentType())
if _, err := ExtractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if !bytes.Equal(remaining, original) {
t.Errorf("multipart body not restored: want %d bytes got %d bytes", len(original), len(remaining))
}
}
func TestExtractContext_URLEncodedBodyRestored(t *testing.T) {
body := "model=whisper-1&extra=value"
r, _ := http.NewRequest(http.MethodPost, "/v1/audio/transcriptions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if _, err := ExtractContext(r); err != nil {
t.Fatalf("ExtractContext: %v", err)
}
remaining, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading body after ExtractContext: %v", err)
}
if string(remaining) != body {
t.Errorf("url-encoded body not restored: want %q got %q", body, string(remaining))
}
}
func TestSetContext(t *testing.T) {
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
data, ok := ctx.Value(ContextKey).(ReqContextData)
if !ok {
t.Fatalf("ContextKey not set or wrong type")
}
if data.Model != "llama3" {
t.Errorf("want %q got %q", "llama3", data.Model)
}
if data.ModelID != "llama3" {
t.Errorf("want %q got %q", "llama3", data.ModelID)
}
}
func TestSetContext_WithAlias(t *testing.T) {
ctx := SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"})
data, _ := ctx.Value(ContextKey).(ReqContextData)
if data.Model != "llama" {
t.Errorf("want requested %q got %q", "llama", data.Model)
}
if data.ModelID != "llama3" {
t.Errorf("want real %q got %q", "llama3", data.ModelID)
}
}
func TestSetContext_DoesNotMutateParent(t *testing.T) {
parent := context.Background()
_ = SetContext(parent, ReqContextData{Model: "llama3", ModelID: "llama3"})
if v := parent.Value(ContextKey); v != nil {
t.Errorf("parent context was mutated: %v", v)
}
}
func TestReadContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
wantReq string
wantReal string
wantBool bool
}{
{
name: "model present, same name",
ctx: SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"}),
wantReq: "llama3",
wantReal: "llama3",
wantBool: true,
},
{
name: "model present, aliased",
ctx: SetContext(context.Background(), ReqContextData{Model: "llama", ModelID: "llama3"}),
wantReq: "llama",
wantReal: "llama3",
wantBool: true,
},
{
name: "model absent",
ctx: context.Background(),
wantReq: "",
wantReal: "",
wantBool: false,
},
{
name: "model is empty string",
ctx: SetContext(context.Background(), ReqContextData{Model: "", ModelID: ""}),
wantReq: "",
wantReal: "",
wantBool: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotData, ok := ReadContext(tt.ctx)
if gotData.Model != tt.wantReq || gotData.ModelID != tt.wantReal || ok != tt.wantBool {
t.Errorf("want (%q, %q, %v) got (%q, %q, %v)", tt.wantReq, tt.wantReal, tt.wantBool, gotData.Model, gotData.ModelID, ok)
}
})
}
}
+266
View File
@@ -0,0 +1,266 @@
package server
import (
"encoding/json"
"net/http"
"sort"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/router"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// modelRecord is one entry in the OpenAI-compatible /v1/models listing.
type modelRecord struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
}
// handleListModels serves the OpenAI-compatible model listing: local models
// (with optional aliases) plus peer models.
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
created := time.Now().Unix()
data := make([]modelRecord, 0, len(s.cfg.Models))
newRecord := func(id, name, description string, metadata map[string]any) modelRecord {
rec := modelRecord{
ID: id,
Object: "model",
Created: created,
OwnedBy: "llama-swap",
Name: strings.TrimSpace(name),
Description: strings.TrimSpace(description),
}
if len(metadata) > 0 {
rec.Meta = map[string]any{"llamaswap": metadata}
}
return rec
}
for id, mc := range s.cfg.Models {
if mc.Unlisted {
continue
}
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata))
if s.cfg.IncludeAliasesInList {
for _, alias := range mc.Aliases {
if alias := strings.TrimSpace(alias); alias != "" {
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata))
}
}
}
}
for peerID, peer := range s.cfg.Peers {
for _, modelID := range peer.Models {
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}))
}
}
sort.Slice(data, func(i, j int) bool { return data[i].ID < data[j].ID })
// Echo the Origin so browser clients can read the listing.
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"object": "list",
"data": data,
})
}
// runningModel is one entry in the /running listing.
type runningModel struct {
Model string `json:"model"`
State string `json:"state"`
Cmd string `json:"cmd"`
Proxy string `json:"proxy"`
TTL int `json:"ttl"`
Name string `json:"name"`
Description string `json:"description"`
}
// handleUnload stops every running local process. Peer models are remote and
// unaffected.
func (s *Server) handleUnload(w http.ResponseWriter, r *http.Request) {
s.local.Unload(0)
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
// handleRunning lists local processes that are not stopped, joining each model
// ID against its config for the cmd/proxy/ttl/name/description metadata.
func (s *Server) handleRunning(w http.ResponseWriter, r *http.Request) {
states := s.local.RunningModels()
list := make([]runningModel, 0, len(states))
for id, state := range states {
mc := s.cfg.Models[id]
list = append(list, runningModel{
Model: id,
State: string(state),
Cmd: mc.Cmd,
Proxy: mc.Proxy,
TTL: mc.UnloadAfter,
Name: mc.Name,
Description: mc.Description,
})
}
sort.Slice(list, func(i, j int) bool { return list[i].Model < list[j].Model })
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{"running": list})
}
// discardResponseWriter satisfies http.ResponseWriter for preload requests,
// dropping the body while capturing the status code.
type discardResponseWriter struct {
header http.Header
status int
}
func (d *discardResponseWriter) Header() http.Header {
if d.header == nil {
d.header = make(http.Header)
}
return d.header
}
func (d *discardResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
func (d *discardResponseWriter) WriteHeader(status int) { d.status = status }
// startPreload fires a background GET / at every model named in
// Hooks.OnStartup.Preload so they are warm before the first real request.
// Preload names are already resolved to real model IDs by config loading.
func (s *Server) startPreload() {
models := s.cfg.Hooks.OnStartup.Preload
if len(models) == 0 {
return
}
go func() {
for _, modelID := range models {
if !s.local.Handles(modelID) {
s.proxylog.Warnf("preload: model %s is not a local model, skipping", modelID)
continue
}
s.proxylog.Infof("preloading model: %s", modelID)
req, err := http.NewRequestWithContext(s.shutdownCtx, http.MethodGet, "/", nil)
if err != nil {
continue
}
req = req.WithContext(router.SetContext(req.Context(), router.ReqContextData{Model: modelID, ModelID: modelID}))
dw := &discardResponseWriter{status: http.StatusOK}
s.local.ServeHTTP(dw, req)
success := dw.status < http.StatusBadRequest
if !success {
s.proxylog.Errorf("failed to preload model %s: status %d", modelID, dw.status)
}
event.Emit(shared.ModelPreloadedEvent{ModelName: modelID, Success: success})
}
}()
}
// handleMetrics serves Prometheus-format performance metrics. Returns 503 when
// performance monitoring is disabled.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
if s.perf == nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte("# performance monitor not available\n"))
return
}
s.perf.MetricsHandler().ServeHTTP(w, r)
}
func handleHealth(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
func handleRootRedirect(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/ui", http.StatusFound)
}
func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/ui/models", http.StatusFound)
}
// handleUpstream proxies ANY request under /upstream/<model>/<path> directly to
// the model's process, bypassing model dispatch by body/query inspection.
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
upstreamPath := r.PathValue("upstreamPath")
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath)
if !found {
router.SendResponse(w, r, http.StatusNotFound, "model not found")
return
}
// Redirect /upstream/model to /upstream/model/ so relative URLs in upstream
// responses resolve. 301 for GET/HEAD, 308 otherwise to preserve the method.
if remainingPath == "/" && !strings.HasSuffix(r.URL.Path, "/") {
newPath := "/upstream/" + searchName + "/"
if r.URL.RawQuery != "" {
newPath += "?" + r.URL.RawQuery
}
if r.Method == http.MethodGet || r.Method == http.MethodHead {
http.Redirect(w, r, newPath, http.StatusMovedPermanently)
} else {
http.Redirect(w, r, newPath, http.StatusPermanentRedirect)
}
return
}
// Strip the /upstream/<model> prefix before forwarding.
r.URL.Path = remainingPath
// Pin the resolved model so the router skips body/query extraction.
*r = *r.WithContext(router.SetContext(r.Context(), router.ReqContextData{Model: searchName, ModelID: modelID}))
switch {
case s.local.Handles(modelID):
s.local.ServeHTTP(w, r)
case s.peer.Handles(modelID):
s.peer.ServeHTTP(w, r)
default:
router.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
}
}
// findModelInPath walks a slash-separated path, building up segments until one
// matches a configured model. This resolves model names that contain slashes
// (e.g. "author/model"). Returns the matched name, its real model ID, the
// remaining path, and whether a match was found.
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
name := ""
for i, part := range parts {
if part == "" {
continue
}
if name == "" {
name = part
} else {
name = name + "/" + part
}
if modelID, ok := cfg.RealModelName(name); ok {
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
}
}
return "", "", "", false
}
+159
View File
@@ -0,0 +1,159 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
)
func TestServer_HandleListModels(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.cfg = config.Config{
Models: map[string]config.ModelConfig{
"visible": {Name: "Visible", Description: "a model"},
"hidden": {Unlisted: true},
},
Peers: config.PeerDictionaryConfig{
"peer1": {Models: []string{"remote-model"}},
},
}
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
req.Header.Set("Origin", "http://example.com")
s.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
t.Errorf("Access-Control-Allow-Origin = %q", got)
}
var resp struct {
Data []modelRecord `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode: %v", err)
}
ids := map[string]bool{}
for _, m := range resp.Data {
ids[m.ID] = true
}
if !ids["visible"] || !ids["remote-model"] {
t.Errorf("missing expected models: %v", ids)
}
if ids["hidden"] {
t.Error("unlisted model should not appear")
}
}
func TestServer_HandleListModels_Aliases(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.cfg = config.Config{
IncludeAliasesInList: true,
Models: map[string]config.ModelConfig{
"real": {Aliases: []string{"nick"}},
},
}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
var resp struct {
Data []modelRecord `json:"data"`
}
json.Unmarshal(w.Body.Bytes(), &resp)
ids := map[string]bool{}
for _, m := range resp.Data {
ids[m.ID] = true
}
if !ids["real"] || !ids["nick"] {
t.Errorf("expected alias entry; got %v", ids)
}
}
func TestServer_FindModelInPath(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{
"author/model": {},
"simple": {},
}}
cases := []struct {
path string
wantName string
wantRem string
wantFound bool
}{
{"/simple/v1/chat", "simple", "/v1/chat", true},
{"/author/model/v1/chat", "author/model", "/v1/chat", true},
{"/author/model", "author/model", "/", true},
{"/missing/v1", "", "", false},
{"/", "", "", false},
}
for _, c := range cases {
name, _, rem, found := findModelInPath(cfg, c.path)
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
}
}
}
func TestServer_HandleUpstream(t *testing.T) {
local := newStubRouter([]string{"m1"}, "upstream-body")
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
t.Run("proxies to local", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat", nil))
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
t.Errorf("status=%d body=%q", w.Code, w.Body.String())
}
})
t.Run("redirects bare model path", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1", nil))
if w.Code != http.StatusMovedPermanently {
t.Errorf("status = %d, want 301", w.Code)
}
})
t.Run("unknown model 404", func(t *testing.T) {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/nope/v1", nil))
if w.Code != http.StatusNotFound {
t.Errorf("status = %d, want 404", w.Code)
}
})
}
func TestServer_HandleMetrics_Unavailable(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/metrics", nil))
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want 503", w.Code)
}
}
func TestServer_Redirects(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
for path, want := range map[string]string{"/": "/ui", "/upstream": "/ui/models"} {
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, path, nil))
if w.Code != http.StatusFound {
t.Errorf("%s: status = %d, want 302", path, w.Code)
}
if got := w.Header().Get("Location"); got != want {
t.Errorf("%s: Location = %q, want %q", path, got, want)
}
}
}
+270
View File
@@ -0,0 +1,270 @@
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/perf"
"github.com/mostlygeek/llama-swap/internal/router"
"github.com/mostlygeek/llama-swap/internal/shared"
)
// apiModel is one entry in the /api/events modelStatus payload.
type apiModel struct {
Id string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
PeerID string `json:"peerID"`
Aliases []string `json:"aliases,omitempty"`
}
// modelStatus returns every configured model joined with its current process
// state (defaulting to "stopped"), followed by peer models.
func (s *Server) modelStatus() []apiModel {
running := s.local.RunningModels()
ids := make([]string, 0, len(s.cfg.Models))
for id := range s.cfg.Models {
ids = append(ids, id)
}
sort.Strings(ids)
models := make([]apiModel, 0, len(ids))
for _, id := range ids {
mc := s.cfg.Models[id]
state := "stopped"
if st, ok := running[id]; ok {
state = string(st)
}
models = append(models, apiModel{
Id: id,
Name: mc.Name,
Description: mc.Description,
State: state,
Unlisted: mc.Unlisted,
Aliases: mc.Aliases,
})
}
for peerID, peer := range s.cfg.Peers {
for _, modelID := range peer.Models {
models = append(models, apiModel{Id: modelID, PeerID: peerID})
}
}
return models
}
// handleAPIUnloadAll stops every running local process.
func (s *Server) handleAPIUnloadAll(w http.ResponseWriter, r *http.Request) {
s.local.Unload(0)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"msg": "ok"})
}
// handleAPIUnloadModel stops a single named local process.
func (s *Server) handleAPIUnloadModel(w http.ResponseWriter, r *http.Request) {
requested := strings.TrimPrefix(r.PathValue("model"), "/")
realName, found := s.cfg.RealModelName(requested)
if !found {
router.SendResponse(w, r, http.StatusNotFound, "model not found")
return
}
if !s.local.Handles(realName) {
router.SendResponse(w, r, http.StatusNotFound, "no local server found for requested model")
return
}
s.local.Unload(0, realName)
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
// handleAPIMetrics serves the activity log as a JSON array.
func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
data, err := s.metrics.getMetricsJSON()
if err != nil {
router.SendResponse(w, r, http.StatusInternalServerError, "failed to get metrics")
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(data)
}
// handleAPIPerformance serves the buffered system/GPU stats, optionally
// filtered to samples after the ?after=<RFC3339> timestamp.
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
if s.perf == nil {
router.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
return
}
sysStats, gpuStats := s.perf.Current()
if afterStr := r.URL.Query().Get("after"); afterStr != "" {
after, err := time.Parse(time.RFC3339, afterStr)
if err != nil {
router.SendResponse(w, r, http.StatusBadRequest, "invalid 'after' timestamp, use RFC3339 format")
return
}
filteredSys := make([]perf.SysStat, 0, len(sysStats))
for _, st := range sysStats {
if st.Timestamp.After(after) {
filteredSys = append(filteredSys, st)
}
}
sysStats = filteredSys
filteredGpu := make([]perf.GpuStat, 0, len(gpuStats))
for _, g := range gpuStats {
if g.Timestamp.After(after) {
filteredGpu = append(filteredGpu, g)
}
}
gpuStats = filteredGpu
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"sys_stats": sysStats,
"gpu_stats": gpuStats,
})
}
// handleAPIVersion serves the build metadata.
func (s *Server) handleAPIVersion(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{
"version": s.build.Version,
"commit": s.build.Commit,
"build_date": s.build.Date,
})
}
// handleAPICapture returns the stored request/response capture for a metric ID.
func (s *Server) handleAPICapture(w http.ResponseWriter, r *http.Request) {
id, err := strconv.Atoi(r.PathValue("id"))
if err != nil {
router.SendResponse(w, r, http.StatusBadRequest, "invalid capture ID")
return
}
capture := s.metrics.getCaptureByID(id)
if capture == nil {
router.SendResponse(w, r, http.StatusNotFound, "capture not found")
return
}
jsonBytes, err := json.Marshal(capture)
if err != nil {
router.SendResponse(w, r, http.StatusInternalServerError, "failed to marshal capture")
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(jsonBytes)
}
type messageType string
const (
msgTypeModelStatus messageType = "modelStatus"
msgTypeLogData messageType = "logData"
msgTypeMetrics messageType = "metrics"
msgTypeInFlight messageType = "inflight"
)
type messageEnvelope struct {
Type messageType `json:"type"`
Data string `json:"data"`
}
// handleAPIEvents streams server events (model status, log data, metrics,
// in-flight counts) to the client as Server-Sent Events.
func (s *Server) handleAPIEvents(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Content-Type-Options", "nosniff")
// prevent nginx from buffering SSE
w.Header().Set("X-Accel-Buffering", "no")
flusher, ok := w.(http.Flusher)
if !ok {
router.SendResponse(w, r, http.StatusInternalServerError, "streaming unsupported")
return
}
// internal/event already has a 50K event buffer
// a 1K message buffer should be enough, watch the logs for the warning that the sendBuffer is full
sendBuffer := make(chan messageEnvelope, 1024)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
send := func(msg messageEnvelope) {
select {
case sendBuffer <- msg:
case <-ctx.Done():
s.proxylog.Warn("handleAPIEvents send suppressed due to context done")
default:
s.proxylog.Warn("handleAPIEvents sendBuffer full, dropped message")
}
}
sendModels := func() {
if data, err := json.Marshal(s.modelStatus()); err == nil {
send(messageEnvelope{Type: msgTypeModelStatus, Data: string(data)})
}
}
sendLogData := func(source string, data []byte) {
if j, err := json.Marshal(map[string]string{"source": source, "data": string(data)}); err == nil {
send(messageEnvelope{Type: msgTypeLogData, Data: string(j)})
}
}
sendMetrics := func(metrics []ActivityLogEntry) {
if j, err := json.Marshal(metrics); err == nil {
send(messageEnvelope{Type: msgTypeMetrics, Data: string(j)})
}
}
sendInFlight := func(total int) {
if j, err := json.Marshal(map[string]int{"total": total}); err == nil {
send(messageEnvelope{Type: msgTypeInFlight, Data: string(j)})
}
}
defer event.On(func(e shared.ProcessStateChangeEvent) { sendModels() })()
defer event.On(func(e shared.ConfigFileChangedEvent) { sendModels() })()
defer s.proxylog.OnLogData(func(data []byte) { sendLogData("proxy", data) })()
defer s.upstreamlog.OnLogData(func(data []byte) { sendLogData("upstream", data) })()
defer event.On(func(e ActivityLogEvent) { sendMetrics([]ActivityLogEntry{e.Metrics}) })()
defer event.On(func(e shared.InFlightRequestsEvent) { sendInFlight(e.Total) })()
// initial payload
sendLogData("proxy", s.proxylog.GetHistory())
sendLogData("upstream", s.upstreamlog.GetHistory())
sendModels()
sendMetrics(s.metrics.getMetrics())
sendInFlight(int(s.inflight.Current()))
for {
select {
case <-r.Context().Done():
return
case <-s.shutdownCtx.Done():
return
case msg := <-sendBuffer:
data, err := json.Marshal(msg)
if err != nil {
continue
}
fmt.Fprintf(w, "event:message\ndata:%s\n\n", data)
flusher.Flush()
}
}
}
+103
View File
@@ -0,0 +1,103 @@
package server
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestServer_InflightMiddleware(t *testing.T) {
c := &inflightCounter{}
mw := CreateInflightMiddleware(c)
var duringRequest int64
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
duringRequest = c.Current()
}))
handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil))
if duringRequest != 1 {
t.Errorf("counter during request = %d, want 1", duringRequest)
}
if got := c.Current(); got != 0 {
t.Errorf("counter after request = %d, want 0", got)
}
}
func TestServer_APIVersion(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.build = BuildInfo{Version: "1.2.3", Commit: "deadbeef", Date: "2026-05-19"}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/version", nil))
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
var got map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got["version"] != "1.2.3" || got["commit"] != "deadbeef" || got["build_date"] != "2026-05-19" {
t.Errorf("body = %v", got)
}
}
func TestServer_APIMetrics_Empty(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/metrics", nil))
if w.Code != http.StatusOK {
t.Fatalf("status = %d", w.Code)
}
if body := strings.TrimSpace(w.Body.String()); body != "[]" {
t.Errorf("body = %q, want []", body)
}
}
func TestServer_APIPerformance_Unavailable(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/performance", nil))
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want 503", w.Code)
}
}
func TestServer_APIEvents_InitialPayload(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest(http.MethodGet, "/api/events", nil).WithContext(ctx)
w := httptest.NewRecorder()
done := make(chan struct{})
go func() {
s.ServeHTTP(w, req)
close(done)
}()
time.Sleep(100 * time.Millisecond)
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("handler did not return after context cancel")
}
body := w.Body.String()
for _, want := range []string{`"type":"modelStatus"`, `"type":"inflight"`, `"type":"logData"`} {
if !strings.Contains(body, want) {
t.Errorf("initial SSE payload missing %s; body=%q", want, body)
}
}
}
+135
View File
@@ -0,0 +1,135 @@
package server
import (
"encoding/base64"
"net/http"
"strings"
"github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/router"
)
// CreateAuthMiddleware returns middleware that validates API keys when the
// config declares any. It accepts the key via Authorization: Bearer,
// Authorization: Basic (password field), or x-api-key. On success the auth
// headers are stripped so they never leak to upstream. When no keys are
// configured the middleware is a pass-through.
func CreateAuthMiddleware(cfg config.Config) chain.Middleware {
keys := cfg.RequiredAPIKeys
return func(next http.Handler) http.Handler {
if len(keys) == 0 {
return next
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
provided := extractAPIKey(r)
valid := false
for _, key := range keys {
if provided == key {
valid = true
break
}
}
if !valid {
w.Header().Set("WWW-Authenticate", `Basic realm="llama-swap"`)
router.SendResponse(w, r, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
return
}
r.Header.Del("Authorization")
r.Header.Del("x-api-key")
next.ServeHTTP(w, r)
})
}
}
// extractAPIKey pulls a candidate API key from the request, preferring Basic,
// then Bearer, then x-api-key.
func extractAPIKey(r *http.Request) string {
var bearerKey, basicKey string
if auth := r.Header.Get("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
bearerKey = strings.TrimPrefix(auth, "Bearer ")
} else if strings.HasPrefix(auth, "Basic ") {
encoded := strings.TrimPrefix(auth, "Basic ")
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
if parts := strings.SplitN(string(decoded), ":", 2); len(parts) == 2 {
basicKey = parts[1] // password field is the API key
}
}
}
}
switch {
case basicKey != "":
return basicKey
case bearerKey != "":
return bearerKey
default:
return r.Header.Get("x-api-key")
}
}
// CreateCORSMiddleware returns middleware that answers OPTIONS preflight
// requests with permissive CORS headers (see issues #81, #77, #42). Non-OPTIONS
// requests pass through untouched.
func CreateCORSMiddleware() chain.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodOptions {
next.ServeHTTP(w, r)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
w.Header().Set("Access-Control-Allow-Headers", sanitizeAccessControlRequestHeaderValues(headers))
} else {
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept, X-Requested-With")
}
w.Header().Set("Access-Control-Max-Age", "86400")
w.WriteHeader(http.StatusNoContent)
})
}
}
func isTokenChar(r rune) bool {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
default:
return false
}
return true
}
// sanitizeAccessControlRequestHeaderValues drops any header names that contain
// characters outside the HTTP token grammar before echoing them back.
func sanitizeAccessControlRequestHeaderValues(headerValues string) string {
parts := strings.Split(headerValues, ",")
valid := make([]string, 0, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
validPart := true
for _, c := range v {
if !isTokenChar(c) {
validPart = false
break
}
}
if validPart {
valid = append(valid, v)
}
}
return strings.Join(valid, ", ")
}
+120
View File
@@ -0,0 +1,120 @@
package server
import (
"encoding/base64"
"net/http"
"net/http/httptest"
"testing"
"github.com/mostlygeek/llama-swap/internal/config"
)
func TestServer_ExtractAPIKey(t *testing.T) {
basicHeader := func(user, pass string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
}
cases := []struct {
name string
auth string
xapi string
want string
}{
{"none", "", "", ""},
{"bearer", "Bearer tok123", "", "tok123"},
{"basic", basicHeader("user", "pw-key"), "", "pw-key"},
{"x-api-key", "", "xkey", "xkey"},
{"basic beats bearer", basicHeader("u", "bk"), "", "bk"},
{"bearer beats x-api-key", "Bearer btok", "xkey", "btok"},
{"malformed basic falls back to x-api-key", "Basic !!!notbase64", "xkey", "xkey"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
if c.auth != "" {
r.Header.Set("Authorization", c.auth)
}
if c.xapi != "" {
r.Header.Set("x-api-key", c.xapi)
}
if got := extractAPIKey(r); got != c.want {
t.Errorf("extractAPIKey() = %q, want %q", got, c.want)
}
})
}
}
func TestServer_SanitizeAccessControlRequestHeaders(t *testing.T) {
cases := []struct {
in string
want string
}{
{"Content-Type, Authorization", "Content-Type, Authorization"},
{" X-Custom , Accept ", "X-Custom, Accept"},
{"Valid, Bad Header", "Valid"},
{"Bad@Header", ""},
{"", ""},
}
for _, c := range cases {
if got := sanitizeAccessControlRequestHeaderValues(c.in); got != c.want {
t.Errorf("sanitize(%q) = %q, want %q", c.in, got, c.want)
}
}
}
func TestServer_IsTokenChar(t *testing.T) {
for _, r := range "abcXYZ0129!#$%&'*+-.^_`|~" {
if !isTokenChar(r) {
t.Errorf("isTokenChar(%q) = false, want true", r)
}
}
for _, r := range " @()/\t\"" {
if isTokenChar(r) {
t.Errorf("isTokenChar(%q) = true, want false", r)
}
}
}
func TestServer_AuthMiddleware(t *testing.T) {
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "" || r.Header.Get("x-api-key") != "" {
t.Error("auth headers leaked to upstream")
}
w.WriteHeader(http.StatusOK)
})
t.Run("no keys configured passes through", func(t *testing.T) {
mw := CreateAuthMiddleware(config.Config{})
w := httptest.NewRecorder()
mw(final).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
cfg := config.Config{RequiredAPIKeys: []string{"secret"}}
t.Run("valid key", func(t *testing.T) {
mw := CreateAuthMiddleware(cfg)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Authorization", "Bearer secret")
w := httptest.NewRecorder()
mw(final).ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
t.Run("invalid key", func(t *testing.T) {
mw := CreateAuthMiddleware(cfg)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Authorization", "Bearer wrong")
w := httptest.NewRecorder()
mw(final).ServeHTTP(w, r)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", w.Code)
}
if w.Header().Get("WWW-Authenticate") == "" {
t.Error("missing WWW-Authenticate header")
}
})
}
+176
View File
@@ -0,0 +1,176 @@
package server
import (
"fmt"
"net/http"
"strings"
"sync"
"github.com/fxamacker/cbor/v2"
"github.com/klauspost/compress/zstd"
)
// ReqRespCapture is a stored request/response pair for a single metered request.
type ReqRespCapture struct {
ID int `json:"id"`
ReqPath string `json:"req_path"`
ReqHeaders map[string]string `json:"req_headers"`
ReqBody []byte `json:"req_body"`
RespHeaders map[string]string `json:"resp_headers"`
RespBody []byte `json:"resp_body"`
}
// captureFields is a bitmask controlling what a route stores in a ReqRespCapture.
type captureFields uint
const (
captureReqHeaders captureFields = 1 << iota
captureReqBody
captureRespHeaders
captureRespBody
)
const (
captureReqAll = captureReqHeaders | captureReqBody
captureRespAll = captureRespHeaders | captureRespBody
captureAll = captureReqAll | captureRespAll
)
// captureFieldsByPath overrides the default capture mask for routes carrying
// large binary payloads (audio/image) where storing the full body is wasteful.
var captureFieldsByPath = map[string]captureFields{
"/v1/audio/speech": captureReqAll | captureRespHeaders,
"/v1/audio/voices": captureReqHeaders | captureRespAll,
"/v1/audio/transcriptions": captureReqHeaders | captureRespHeaders | captureRespBody,
"/v1/images/generations": captureReqAll | captureRespHeaders,
"/v1/images/edits": captureReqHeaders | captureRespHeaders,
"/sdapi/v1/txt2img": captureReqAll | captureRespHeaders,
"/sdapi/v1/img2img": captureReqHeaders | captureRespHeaders,
}
// captureFieldsFor returns the capture mask for a request path. Unlisted routes
// (the OpenAI-compatible JSON endpoints) capture everything.
func captureFieldsFor(path string) captureFields {
if cf, ok := captureFieldsByPath[path]; ok {
return cf
}
return captureAll
}
// zstdEncOptions are the shared zstd encoder options for maximum compression.
var zstdEncOptions = []zstd.EOption{
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
}
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
var zstdEncPool = &sync.Pool{
New: func() interface{} {
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
return enc
},
}
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
var zstdDecPool = &sync.Pool{
New: func() interface{} {
dec, _ := zstd.NewReader(nil)
return dec
},
}
// compressCapture marshals a ReqRespCapture to CBOR and compresses it with zstd.
// Returns the compressed bytes and the original CBOR byte count for logging.
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
cborBytes, err := cbor.Marshal(c)
if err != nil {
return nil, 0, fmt.Errorf("marshal capture: %w", err)
}
zenc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(zenc)
return zenc.EncodeAll(cborBytes, nil), len(cborBytes), nil
}
// decompressCapture decompresses zstd-compressed CBOR into a ReqRespCapture.
func decompressCapture(data []byte) (*ReqRespCapture, error) {
dec := zstdDecPool.Get().(*zstd.Decoder)
defer zstdDecPool.Put(dec)
cborBytes, err := dec.DecodeAll(data, nil)
if err != nil {
return nil, fmt.Errorf("decompress capture: %w", err)
}
var capture ReqRespCapture
if err := cbor.Unmarshal(cborBytes, &capture); err != nil {
return nil, fmt.Errorf("unmarshal capture: %w", err)
}
return &capture, nil
}
// addCapture compresses and stores a capture in the cache. Returns true if the
// capture was stored.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) bool {
if !mp.enableCaptures {
return false
}
compressed, uncompressedBytes, err := compressCapture(&capture)
if err != nil {
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
return false
}
if err := mp.captureCache.Add(capture.ID, compressed); err != nil {
mp.logger.Warnf("capture %d too large (%d bytes), skipping: %v", capture.ID, len(compressed), err)
return false
}
compressionRatio := (1 - float64(len(compressed))/float64(uncompressedBytes)) * 100
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
return true
}
// getCaptureByID decompresses and unmarshals a capture by ID. Returns nil if
// the capture is not found or decompression fails.
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
if mp.captureCache == nil {
return nil
}
data, err := mp.captureCache.Get(id)
if err != nil {
return nil
}
capture, err := decompressCapture(data)
if err != nil {
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
return nil
}
return capture
}
// sensitiveHeaders lists headers that are redacted in captures.
var sensitiveHeaders = map[string]bool{
"authorization": true,
"proxy-authorization": true,
"cookie": true,
"set-cookie": true,
"x-api-key": true,
}
// headerMap flattens an http.Header to a single-value map.
func headerMap(h http.Header) map[string]string {
m := make(map[string]string, len(h))
for key, values := range h {
if len(values) > 0 {
m[key] = values[0]
}
}
return m
}
// redactHeaders replaces sensitive header values in-place with "[REDACTED]".
func redactHeaders(headers map[string]string) {
for key := range headers {
if sensitiveHeaders[strings.ToLower(key)] {
headers[key] = "[REDACTED]"
}
}
}

Some files were not shown because too many files have changed in this diff Show More