Compare commits

...

69 Commits

Author SHA1 Message Date
George 0ab214d1c8 perf: add vendor-agnostic GPU monitoring for Windows (experimental) (#779)
Add GPU monitoring support for AMD and Intel GPUs on Windows using
D3DKMT (DirectX) and PDH performance counters.

- Add PDH-based GPU utilization via \GPU Engine(*)\Utilization
Percentage counter, summing all engine types per adapter (3D, Compute,
Copy, Video).
- Add D3DKMT bindings for adapter enumeration, memory segments, and
adapter perf data.
- Use PDH as primary utilization source (works on all vendors), with
D3DKMT RunningTime as fallback for systems without PDH counters.
- Prefer nvidia-smi when available, fall back to D3DKMT + PDH for
AMD/Intel.
- Backend priority: nvidia-smi -> D3DKMT + PDH -> ErrNoGpuTool.

Verified on AMD 7900XTX GPU with llama.cpp Vulkan & ROCm backend: GPU
utilization correctly shows ~99% during inference, ~0-2% when idle.

---

LLM disclosure: GLM 5.1 & Kimi K2.6 have been used extensively during
exploration and coding to the point that the LLM's wrote over 3/4 of the
code, and I have done additional verification myself.
As such, it should be considered experimental.
Additional verification is needed.

I have tested it on my 7900XTX system with Windows 11, and it works
correctly, but as I only have this one rig, I cannot verify it
everywhere.
2026-06-16 21:49:09 -07:00
Benson Wong d07b063ab6 internal/server,shared: support request metadata (#850)
- add support for http handlers in the request chain to append metadata
to the request
- metrics middleware will include metadata in the activity log 
- update Activity UI to support metadata, drag sort columns
- update Activity UI capture dialog to use more screen space

Updates #834
2026-06-16 21:44:55 -07:00
Benson Wong 826210dac9 .coderabbit.yaml: disable unit_tests 2026-06-16 10:10:17 -07:00
Benson Wong 6cf1317341 schedule,shared: move concurrency 429 limits into scheduler code (#849)
- make concurrency limiting the scheduler.Scheduler's responsibility
- eliminate the separate concurrency limit middleware 
- move concurrencyLimit logic into scheduler.FIFO to maintain backwards compatibility
- add HTTPError from #834 

Updates #834
2026-06-15 22:35:12 -07:00
Wojciech 8e84b2ec4f README.md: add macports install option to README (#848) 2026-06-15 15:58:24 -07:00
Benson Wong ed77385d08 ui: improve manual model load and cancel (#847)
- When a model is manually loaded show a cancel buttton and a queued
status
- Implement cancellation in scheduler.Scheduler interface and FIFO
scheduler
- Add cache bust query parameter to bypass browser cache

Fixes #844
2026-06-14 13:38:10 -07:00
Benson Wong 92b90447e8 Model capabilities 734 (#842)
internal/config,server: implement model capabilities

- define the capabilities of a model using a simple config block on the
model
- v1/models renders out capabilities to be compatible with openrouter,
huggingface chat, and mistral formats for broader compatibility
- add support for capabilities in UI

Fixes #734
2026-06-13 23:23:19 -07:00
Benson Wong 62aea0e83d internal/router,server,shared: refactor auth, libs (#839)
- refactor shared http functionality into internal/shared/http.go
- remove stripping of Authorization and x-api-key
- add Request Context middleware to internal/server
- add /ui and /metrics behind auth middleware, fixes #717

Fix #717
Updates: #834
2026-06-13 10:19:04 -07:00
Benson Wong 8c660dcb90 main: gofmt 2026-06-11 22:16:39 -07:00
Benson Wong f6877b8175 main: show message when listening on network (#836)
fixes: #739
2026-06-11 22:15:14 -07:00
Benson Wong 9b3a33d7b9 Implement new scheduler (#823)
- introduce internal/router/scheduler to decouple routing, swapping and
queuing into interface contracts.
- introduce a new `routing` configuration section that supersedes
`matrix` and `group` while maintaining backwards compatibility
- add FIFO scheduler with prioritized queuing 
- add internal/router/design.md as developer documentation on
implementing new schedulers and routers

Fixes #797
2026-06-10 20:34:25 -07:00
Benson Wong 0cfe5a6639 Makefile,internal: fix websocket regression and other small things (#830)
- fix websocket regression and add test to prevent in the future
- fix staticheck errors
- remove proxy package remnants from Makefile 

fix #829
2026-06-09 21:37:53 -07:00
Benson Wong 44e1501e81 internal/process,server: fix unload regression (#828)
In v221 the shutdown behaviour was refactored so shutdown behaviour was
more reliable in stopping a process group. This exposed an existing bug
where the unload API had a timeout of 0 that snuck in during the big
refactor.

- set a default timeout of 10 seconds for unloads called via the API
- add logging around shutdown routine

updates: #807, #808
fixes: #827
2026-06-09 20:49:58 -07:00
Benson Wong 46cea36bc2 proxy: remove legacy code. Thanks champ 🫡 (#822)
Fixes #820
2026-06-06 21:00:30 -07:00
Benson Wong ccfba0df28 docker: fix arm64 cpu image downloading amd64 llama-swap binary (#819)
Replace TARGETARCH build-arg with runtime arch detection via uname -m.
BuildKit's TARGETARCH injection was unreliable for the multi-arch cpu
build, causing the arm64 image variant to download and embed the x86_64
llama-swap binary — resulting in "exec format error" on arm64 hosts.

With QEMU user-space emulation, uname -m correctly returns aarch64
inside an arm64 container build, so the download always fetches the
right binary for the actual target architecture. Also adds --fail to
curl so HTTP 404s produce a build error instead of silently embedding an
HTML error page.

fixes #818

Co-authored-by: Claude <noreply@anthropic.com>
2026-06-04 14:26:21 -07:00
Benson Wong ddfae90b19 Change cron schedule for container builds
Shift the non-unified container builds about 8 hours after the llama.cpp's projects container publishing window. The llama.cpp containers take a few hours to build and publish and 8 hours is expected to be enough time to remain fresh. 

Additionally, add an extra build at 18:00 in case the 12:00 one does not pick things up. The container builds on the llama-swap side are cheap (just injecting llama-swap binary) so it is fine to run them a bit more frequently.
2026-06-04 11:00:43 -07:00
Benson Wong 29d3d9ba20 perf: add macOS GPU monitoring via mactop and ioreg (#816)
Implement performance monitoring on OSX for Apple Silicon hardware. 

The implementation uses a combination of mactop and ioreg. If mactop is
installed (`brew install mactop`) it is used in a headless cli mode to
stream usage metrics. mactop hooks into unpublished(?) C based APIs in
OSX. Rather than introduce a cgo dependency into llama-swap's build
chain only for darwin I opted to go the external process route.

ioreg, which comes bundled with OSX is used as the fallback. It does not
provide temperature and power usage data but is able to show accurate
GPU and memory utilization.

Updates #771, #814
2026-06-03 21:51:03 -07:00
Benson Wong 9be9a87fa0 internal/process: improve windows shutdown behaviour (#808)
Add Windows specific shutdown code paths so stopping of child processes
is more reliable:

- stopping llama-swap won't leave behind any child processes it created
- uses Job Objects in Windows so the whole llama-swap tree is closed by
the os
- add procCtx to baseRouter. It replaces shutdownCtx as a signal for
managing lifetime state.
- shutdownCtx is only used by the router to stop handling new requests
during shutdown
- improve debug logging to make it easier to trace source of issues

Fixes #804
Updates #807
2026-06-01 00:45:30 -07:00
Benson Wong 6ea551362e process,router: make model shutdown and load-streaming robust
Note: The original proxy/process_unix.go had a noop for setProcAttributes
so it also did not stop grandchildren processes. This patch adds that capability 
and improves reliability.

--

Stop() no longer hangs on a shell wrapper that forks the real binary.
The upstream is built with exec.CommandContext + cmd.Cancel +
cmd.WaitDelay, so cmd.Wait() returns even when a forked grandchild
inherits the stdout/stderr pipes. killProcess sends the stop signal
directly (not by cancelling the context) so cmd.WaitDelay measures from
process exit and never silently caps the caller's graceful timeout.

The upstream is also started in its own process group (Setpgid) on Unix,
so the graceful SIGTERM — and the SIGKILL escalation after the timeout —
are delivered to the whole group via the negative PID. A forked
grandchild is reaped with its parent instead of leaking as an orphan.

The loading-spinner SSE goroutine can no longer panic when it outlives
the request. net/http recycles the response writer via Reset(nil) once
ServeHTTP returns; the orphaned goroutine then flushed against a
nil-backed writer and crashed with a SIGSEGV. A release() fence on
loadingWriter lets any in-flight write finish then short-circuits later
writes/flushes, and all three ServeHTTP select branches run a
finishLoading helper (cancelLoad, waitForCompletion, release) before the
writer is reclaimed.

- internal/process: exec.CommandContext + WaitDelay, Setpgid process
groups, group-wide SIGTERM/SIGKILL teardown
- internal/router: release() fence + finishLoading on loadingWriter

fixes #804
2026-05-31 10:11:12 -07:00
Benson Wong 03d58e53fa Add load testing tool to the UI (#805)
Wouldn't it be nice to test the performance, swapping and concurrency
from the UI? Now we can! This is a port of `cmd/test-concurrency` into the UI

Here's a demo of it working with a swap matrix: 

https://github.com/user-attachments/assets/b6bb12ec-0381-46f1-a6b8-27d1c3c0ddb3
2026-05-30 17:04:30 -07:00
Luiszzzor c790d0ee03 fix: update the concurrency middleware to respond with a JSON payload (#798)
update the concurrency middleware to respond with a JSON payload instead
of plain text when the request limit is reached to be compatible with
openai api standard

---------

Co-authored-by: Ludwik <l.czarnota@samsung.com>
2026-05-29 23:59:32 -07:00
Benson Wong 4ca9c478a2 Makefile,internal/server: various release tweaks 2026-05-29 15:27:08 -07:00
Benson Wong 146a9eab24 ui-svelte: update build directory (#801)
Fixes #799
2026-05-29 14:45:05 -07:00
Benson Wong 02e015fa49 Introduce new routing backend (#790)
This is a huge backend change that essentially started with rewriting
the concurrency handling for processes and blew up to a refactor of the
entire application. In short these are the improvements:

**Better state and life cycle management:** 

Life cycle management of processes has always been the trickiest part of
the code. Juggling mutex locks between multiple locations to reduce race
conditions was complex. Too complex for my feeble brain to build a
simple mental model around as llama-swap gained more features. All of
that has been refactored. Most of the locks are gone, replaced with a
single run() that owns all state changes. There is one place to start
from now to understand and extend routing logic.

The improved life cycle management makes it easier to implement more
complex swap optimization strategies in the future like #727.

**Collation of requests:**

llama-swap previously handled requests and swapping in the order they
came in. For example requests for models in this order ABCABC would
result in 5 swaps. Now those requests are handled in this order AABBCC.
The result is less time waiting for swap under a high churn request
queue. This fixes #588 #612.

A possible future enhancement is to support a starvation parameter so
swap can be forced when models have been waiting too long.

**Shared base implementation for groups and swap matrix:** 

During the refactor it became clear that much of the swapping logic was
shared between these two implementations. That is not surprising
considering the swap matrix was added many moons after groups. Now they
share a common base and their specific swap strategies are implemented
into the swapPlanner interface.

Requests for bespoke or specific swapping scenarios is a common theme in
the issues. Now users can implement whatever bespoke and weird swapping
strategy they want in their own fork. Just ask your agent of choice to
implement swapPlanner. I'll still remaining more conservative on what
actually lands in core llama-swap and will continue to evaluate PRs if
the changes is good for everyone or just one specific use case.

**AI / Agentic Disclosure:** 

I paid very close attention to the low level swap concurrency design and
implementation. It's important to keep that essential part reliable,
boring and no surprises. Backwards compatibility was also maintained,
even the one way non-exclusive group model loading behaviour that people
have rightly pointed out be a weird design decision.

With the underlying swap core done the web server, api and UI sitting on
top were largely ported over with Claude Code and Opus 4.7 in multiple
phases. If you're curious I kept the changes in docs/newrouter-todo.md.
I did several passes to make sure things weren't left behind.

However, even frontier LLMs at the time of this PR still make small
decisions that don't make a lot of sense. They get shit wrong all the
time, just in small subtle way.

That said, there's likely to be some new bugs introduced with this
massive refactor. I'm fairly confident that there's no major
architectural flaws that would cause goal seeking agents to make dumb,
ugly code decisions.

For a little while the legacy llama-swap will be available under
cmd/legacy/llama-swap. The plan is to eventually delete that entry point
as well as the proxy package.

On a bit of a personal note, this PR is exciting and a bit sad for me. I
hand wrote much of the original code and this PR ultimately replaces
much of it. While the old code served as a good reference for the agent
to implement the new stuff it still a bit sad to eventually delete it
all.
2026-05-28 21:47:01 -07:00
Cr4xy 63bc266395 Add new power draw column header for rocm-smi monitoring (#788)
# Overview
This patch fixes
https://github.com/mostlygeek/llama-swap/pull/775#issuecomment-4535303706
and removes some unnecessary `break` statements.

## The third variant now also works with power draw:
`
device,Device Name,Device ID,Device Rev,Subsystem ID,GUID,Temperature
(Sensor edge) (C),Temperature (Sensor junction) (C),Temperature (Sensor
memory) (C),Average Graphics Package Power (W),GPU use (%),GPU Memory
Allocated (VRAM%),GPU Memory Read/Write Activity (%),Memory
Activity,Avg. Memory Bandwidth,VRAM Total Memory (B),VRAM Total Used
Memory (B),Card Series,Card Model,Card Vendor,Card SKU,Node ID,GFX
Version
`
<img width="1121" height="315" alt="image"
src="https://github.com/user-attachments/assets/4b908c4d-2401-4dfe-9bac-e7aa770cfb42"
/>

## Old variants:
`
device,Device Name,Device ID,Device Rev,Subsystem ID,GUID,Temperature
(Sensor edge) (C),Temperature (Sensor junction) (C),Temperature (Sensor
memory) (C),Fan speed (level),Fan speed (%),Fan RPM,Current Socket
Graphics Package Power (W),GPU use (%),GPU Memory Allocated (VRAM%),GPU
Memory Read/Write Activity (%),Memory Activity,VRAM Total Memory
(B),VRAM Total Used Memory (B),Card Series,Card Model,Card Vendor,Card
SKU,Node ID,GFX Version
`
<img width="1118" height="308" alt="image"
src="https://github.com/user-attachments/assets/b236e0cd-4505-42e5-b497-cff62c720e3d"
/>

`
device,Device Name,Device ID,Device Rev,Subsystem ID,GUID,Temperature
(Sensor edge) (C),Current Socket Graphics Package Power (W),GPU use
(%),GPU Memory Allocated (VRAM%),Memory Activity,VRAM Total Memory
(B),VRAM Total Used Memory (B),Card Series,Card Model,Card Vendor,Card
SKU,Node ID,GFX Version
`
<img width="1120" height="312" alt="image"
src="https://github.com/user-attachments/assets/1adde1c3-5f35-4db4-ba13-65751ac076e8"
/>
2026-05-25 11:36:16 -07:00
Cr4xy 636b53e70f Improve rocm-smi performance monitoring (#775)
Fix hardcoded indices for rocm-smi.
2026-05-20 17:59:49 -07:00
gatkisson 59cd3b690d Added Windows performance monitoring using nvidia-smi (#773)
updates: #596, #771
2026-05-18 11:02:03 -07:00
Benson Wong 5d1e62d224 Disable auto review feature in coderabbit config 2026-05-18 10:40:21 -07:00
Benson Wong dbb869d019 Increase inactivity thresholds for stale issues
Updated stale issue and close messages to reflect new inactivity thresholds.
2026-05-17 22:52:58 -07:00
Benson Wong 26bb17e57e config.example.yaml: Improve matrix vs groups info
For some use cases groups are simpler to use. Note this in the
documentation that it is still fully supported.
2026-05-17 15:59:25 -07:00
Benson Wong 2982dd3d40 ui-svelte: update link to performance discussion thread 2026-05-17 11:45:56 -07:00
knguyen298 79dc87f881 Add ROCm stats via rocm-smi (#767)
Add ROCm GPU stats support using `rocm-smi`.
2026-05-17 07:58:26 -07:00
krzychdre b2fcc2daa1 ui-svelte: fix cached tokens total counting -1 sentinel (#760)
The backend uses cache_tokens=-1 as a sentinel for endpoints that don't
report cache stats (embeddings, vLLM). The activity table correctly
renders these as "-", but the totals widget summed the sentinels
directly, so each such request subtracted 1 from the displayed total.

- clamp cache_tokens with Math.max(0, ...) when reducing
2026-05-15 14:42:44 -07:00
cdwaage 6a9c4efc8f fix: use --loop instead of -loop for nvidia-smi (driver 540+ compat) (#759) 2026-05-15 13:20:29 -07:00
Benson Wong 0c813e44d1 ui-svelte: package updates 2026-05-14 21:56:04 -07:00
Benson Wong fe71e8a6ea proxy,ui-svelte: improve support for v1/messages and v1/responses (#758)
This improves the support for activity logging from the v1/responses and
v1/messages endpoints.

- add chat endpoint selection to Playground > Chat > Settings
- improve metrics extraction for streaming v1/messages and v1/responses
endpoints (tested with llama-server)

Fixes #742
2026-05-14 21:53:57 -07:00
Benson Wong aac7b8745a ci: set go-version-file in release workflow 2026-05-13 22:12:02 -07:00
Benson Wong 4e606feff0 ci: fix workflow bugs in release and go-ci
- release.yml: merge orphaned `uses:` into the Checkout step
- go-ci.yml: skip simple-responder build when restored from cache
2026-05-13 21:48:27 -07:00
Benson Wong a4b91e08cf Changes and fixes before the release (docs/small tweaks) (#750)
- update README.md with new docker instructions
- update docs/configuration.md
- update .github/workflows to have pinned action versions
- gofmt events package
- fix small bugs in CI scripts
- reduce config options for internal/perf/monitor and config. A ring buffer is used to keep 1hr of entries at max 5s granularity. For long term stats use prometheus monitoring on /metrics

Fixes #744
2026-05-13 21:18:19 -07:00
David Soušek 3e3646f9f9 perf: ignore LACT devices reporting zero VRAM (#753)
Ignore LACT devices that report zero total VRAM.

Some virtual GPUs on headless VMs report `MemTotalMB == 0` through LACT,
which makes them appear in performance monitoring despite not providing
useful memory data. Skip those entries so only usable GPU devices are
reported.

This makes performance monitoring cleaner on headless VMs with virtual
GPUs that report zero VRAM.

Co-authored-by: David Soušek <david.sousek@intelogy.co.uk>
2026-05-13 10:03:54 -07:00
rhtenhove a01afe261b ci: use manifest-aware cleanup action for multi-arch :cpu (#751)
actions/delete-package-versions can't see OCI manifest lists. When the
cpu build pushes a multi-arch image, the registry gets a tagged index
plus one untagged per-platform manifest per arch. The cleanup step with
`delete-only-untagged-versions: true` then deletes the per-platform
children, leaving the index dangling — `docker pull
ghcr.io/mostlygeek/llama-swap:cpu` 404s on the referenced sha.

Swap to dataaxiom/ghcr-cleanup-action, which inspects tagged manifest
lists first and excludes their children from deletion. Single-arch
backends behave the same as before.

Fix #746
2026-05-12 18:04:46 -07:00
rhtenhove 174e8562aa Multi arch cpu (#746)
Encountered a similar problem as in
https://github.com/mostlygeek/llama-swap/issues/709 but in my case I
only needed the :cpu version.

So decided to add the github action to build arm64 combined with the
amd64 version on the same :cpu tag. Already tested it from this fork:
ghcr.io/rhtenhove/llama-swap:cpu and it works perfectly fine.

Adding GPU support is a whole other beast, needing quite a bit more work
and isn't something I can test.
2026-05-11 21:03:48 -07:00
Abdulazez A. 085b54bc88 proxy: fix data race in /running endpoint and typo in error message (#748)
## Problem

The `/running` endpoint in `listRunningProcessesHandler` reads
`process.state` directly without holding `stateMutex`. Meanwhile,
`swapState()` writes to `process.state` while holding the write lock.
This is a data race flagged by the Go race detector.

Also fixes a minor typo: "processes was in state" → "process was in
state".

## Fix

- `proxymanager.go`: Replace `process.state` with
`process.CurrentState()` which acquires `stateMutex.RLock()` before
reading.
- `process.go`: Fix typo in error message.

## Verification

- `gofmt -l` — clean
- `go test -run "TestProcessGroup_|TestProxyManager_" ./proxy/` — all
pass
- `go test ./proxy/config/... ./proxy/cache/...
./proxy/configwatcher/...` — all pass
2026-05-11 12:49:18 -07:00
bankjaneo 2be3416baa ui: add auto theme switch mode based on system theme (#741)
Add system theme detection with automatic switching when OS theme
changes.

- Add ThemeMode type with "light", "dark", and "system" options
- Add system theme listener using matchMedia API
- Update theme toggle to cycle through System → Light → Dark
- Add combined sun/moon icon for system theme mode
- Migrate existing theme preferences to new format
2026-05-09 20:22:18 -07:00
Benson Wong 7e3e94a08a proxy,ui: add performance monitoring with Prometheus metrics (#743)
Add a comprehensive performance monitoring system that collects CPU, memory, swap, load average, network IO, and GPU stats. Provides both a REST API for the UI and a Prometheus /metrics endpoint.

Backend changes:
- New internal/perf package with configurable interval-based stats collection
- GPU monitoring via LACT (Unix socket) and nvidia-smi fallback on Linux
- Ring buffer (internal/ring) for time-series stat storage
- Prometheus /metrics endpoint with all system and GPU metrics
- Moved LogMonitor to internal/logmon package
- New PerformanceConfig for hot-reloadable monitoring settings
- REST /api/performance endpoint replacing SSE streaming

UI changes:
- New Performance page with real-time charts for CPU, memory, GPU, and network
- Reusable PerformanceChart component
- LLAMA_SWAP_URL environment variable support
- Improved capture dialog display

Other:
- Example Grafana dashboard for Prometheus metrics
- monitor-test standalone binary
- Config schema and example updates

fixes #596
2026-05-09 13:29:22 -07:00
Wim Vander Schelden e261745c66 proxy: add versionless API endpoint (#733)
Add versionless endpoints under v/ to support upstream peers that 
do not use the v1/ prefix.

Fixes #728.
2026-05-03 13:47:38 -07:00
Benson Wong 11b7913287 llama-swap.go: remove debounce, replace fmt.Printlns (#731)
small fixes to clean up the main(): 

- remove the debounced config reload 
- replace fmt.Println with a proxy.LogMonitor for consistency
2026-05-02 16:28:53 -07:00
Marcus c79114d40a proxy: fix logger not checking matrix for processes
Fix matrix not being used to search for a logger causing /logs/stream/model_name to return an error
2026-05-01 16:43:20 -07:00
Benson Wong 430166d5eb proxy: fix zero duration for non streaming responses (#723)
Updates #654
2026-04-30 19:51:28 -07:00
Marcus 5b4beaceef fix: ?no-history flag and improve /logs monitoring docs (#721)
- improve logging documentation 
- small tweaks for edge case issues in upstream and log requests
2026-04-30 00:50:36 -07:00
Benson Wong fd3c28ffc5 Refactor Activity Page (#710)
- inference handles to store an activity record for all inference endpoints
- add path, status code, and content type to Activities page
- toggle on/off columns no Activities page 
- add configurable capture level for inference endpoints so large binary blobs are not stored in memory
- store captures in compressed binary format
2026-04-28 20:33:03 -07:00
Quentin Machu a846c4f18c config: remove hard cap on macro length (#718)
Remove macro value limit of 1024 characters
2026-04-28 13:32:54 -07:00
Marcus 5bae33a769 ui-svelte: default theme to user preferred color scheme (#712)
Simple, if not set is localStorage use whatever the user's preferred
color scheme is to start.
2026-04-27 06:44:22 -07:00
Benson Wong 8f4ff01f93 ui-svelte: make it easier to toggle panels in logs view 2026-04-26 22:12:43 -07:00
Benson Wong e8d4384cd2 ui-svelte: support reasoning and reasoning_content (#708)
Support `reasoning` v1/chat/completion delta that vLLM uses.
2026-04-26 13:11:48 -07:00
Benson Wong ce28485be2 ui-svelte: add prompt processing histogram (#705)
Activities page shows histograms for prompt processing and token generation times. 

Fix: #691
Fix: #703
2026-04-25 16:13:07 -07:00
Damir 3cd7837b1f fix: support architecture-specific download URLs in install script (#698)
Just a small fix to include proper llama-swap binary when building the
arm64 architecture.
2026-04-23 18:05:33 -07:00
Benson Wong 0b31ccacc1 ui-svelte: fix histogram calculation (#695)
- Fix the histogram calculation to use server provided generation
tokens/second.
- Move histogram to Activities page where it can exist with the rest of
the token metrics

Fixes #681
2026-04-22 23:42:39 -07:00
Bryan Gahagan 5938dbee8f Push unified docker images on scheduled runs (#694)
Fixes #693
2026-04-22 20:46:51 -07:00
Benson Wong 66639e83f7 proxy: replace fsnotify with stat-poll watcher and add SIGHUP reload (#685)
The fsnotify-based config watcher does not work reliably when the config
file is bind-mounted into a Docker container as an individual file, and
mishandles k8s ConfigMap projections (atomically swapped symlinks).
Replace it with a small os.Stat-polling watcher and add SIGHUP as an
explicit reload signal.

- new proxy/configwatcher package: 2s os.Stat poller, follows symlinks,
  fires on mtime/size change and on missing -> present transitions
- SIGHUP triggers reload unconditionally (works without --watch-config)
  via the same ConfigFileChangedEvent pipeline so the UI sees identical
  state transitions
- watcher goroutine now exits cleanly on shutdown via a context
- drop github.com/fsnotify/fsnotify dependency

fixes #682
2026-04-21 23:21:48 -07:00
Benson Wong 625b296720 docker/unified: add uv via pip install (#681)
Install uv after the cpp tool binaries are copied and before the
llama-swap binary, enabling `uv run` usage for Python-based inference
backends like vLLM.

- add python3-pip to runtime apt installs
- add `pip install uv --break-system-packages` after cpp installs

fixes #628

Co-authored-by: Claude <noreply@anthropic.com>
2026-04-20 20:55:51 -07:00
Benson Wong 231e62291c proxy: fix matrix race and process stop bug (#677)
- matrix.go change logic to consider any proxy.Process not in
StateStopped or StateShutdown
- process.StopImmediately, and Stop() which called it had a subtle bug
where it only handled state transitions from StateReady to
StateStopping. StateStarting -> StateStopping was ignored completely.

fix: #670
2026-04-20 00:21:11 -07:00
Benson Wong 57ac666598 .github/workflows: tweak push ghcr conditional (#676) 2026-04-19 13:56:26 -07:00
Benson Wong 69728301f5 .github/workflows: add toggle for pushing unified images to github (#672)
Add ability to dispatch (manually run) unified container builds in github without push to ghcr.io.
2026-04-19 10:10:48 -07:00
Benson Wong c176fa70f1 docker/unified: add spirv-headers to fix vulkan build (#669) 2026-04-18 12:18:10 -07:00
Benson Wong 5e3c646829 proxy: compress captures with zstd (#668)
The previous captures were saved uncompressed in memory. In agentic
workflows there can be many turns with each request containing the
previous context in the body with a lot of redundant data. Use zstd to
compress the request and response data before keeping a copy of memory.

Results: 

- Average Percentage Saved: 73.19%
- Average Compression Factor: ~6.77:1
2026-04-17 23:29:37 -07:00
Benson Wong c3f0d43e6e proxy: fix race conditions during swap (#667)
I pointed Opus 4.7 (high effort) at proxy.ProcessGroup to identify any
race conditions in the swapping code. It found a race condition where
there is a small window in the fast path for routing a request to a
loaded model. There is a very small window where:

- model M1 is loaded and ready for requests
- a request, R1, for M1 comes in 
- a request, R2, for M2 comes in almost immediately after
- R1 acquires the lock, sees M1 is loaded (fast path), releases the lock
`[race window]` and the request is ready to be forwarded
- the race window occurs between the release of the lock and the request
being forwarded
  - the lock is released so requests can be handled concurrently 
- R2 comes in within the `[race window]`, acquires the lock, triggers a
model swap to M2. stopping M1
- R1 is forwarded to a model that is unloaded or in the process of
shutting down creating an error response

In deployed systems the race window is very small and doesn't happen
often. However with #635 and PR #656 I though this deserved a bit more
attention. It is not concluded that this race is the cause of #635 but
the race is likely to happen more often under sustained or high load.

AI Note: Opus 4.7 x-high effort took about an hour to write the original
patch. With the pattern discovered the fix to matrix.go was very quick.
GLM 5.1 using the previous established patterns was able to easily write
the fix for ProcessGroup.StopProcesses().

Supersedes: #656
Updates: #277, #635
2026-04-17 21:23:17 -07:00
Benson Wong f6cf9f5844 proxy: Refactor tests (#660)
- use YAML for test configurations
- remove most uses of simple-responder, opting to use
process.testHandler

Fixes #655
2026-04-16 22:47:42 -07:00
Benson Wong 121fd93ad8 Makefile: restore linux arm64 targets
Fix #641
2026-04-14 22:05:39 -07:00
191 changed files with 23992 additions and 9918 deletions
+3 -1
View File
@@ -13,8 +13,10 @@ reviews:
docstrings: docstrings:
enabled: false enabled: false
auto_review: auto_review:
enabled: true enabled: false
drafts: false drafts: false
unit_tests:
enabled: false
chat: chat:
auto_reply: true auto_reply: true
issue_enrichment: issue_enrichment:
+5 -5
View File
@@ -11,13 +11,13 @@ jobs:
issues: write issues: write
pull-requests: write pull-requests: write
steps: steps:
- uses: actions/stale@v9 - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f #v10.2.0
with: with:
days-before-issue-stale: 14 days-before-issue-stale: 30
days-before-issue-close: 14 days-before-issue-close: 30
stale-issue-label: "stale" stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity." stale-issue-message: "This issue is stale because it has been open without activity for 30 days. Please remove the stale label if this was an error."
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale." close-issue-message: "This issue was closed because it has been inactive for 30 days since being marked as stale."
days-before-pr-stale: -1 days-before-pr-stale: -1
days-before-pr-close: -1 days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
+5 -8
View File
@@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Validate JSON Schema - name: Validate JSON Schema
run: | run: |
@@ -44,13 +44,10 @@ jobs:
echo "✓ config-schema.json is valid" echo "✓ config-schema.json is valid"
- name: Set up Python - name: Set up Go
uses: actions/setup-python@v5 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
with: with:
python-version: "3.x" go-version-file: go.mod
- name: Install check-jsonschema
run: pip install check-jsonschema
- name: Validate config.example.yaml against schema - name: Validate config.example.yaml against schema
run: check-jsonschema --schemafile config-schema.json config.example.yaml run: go test ./internal/config/ -run TestConfig_ExampleMatchesSchema -v
+34 -10
View File
@@ -2,13 +2,18 @@ name: Build Containers
on: on:
# time has no specific meaning, trying to time it after # time has no specific meaning, trying to time it after
# the llama.cpp daily packages are published # the llama.cpp daily packages have time to build and publish (~8hr after llama.cpp project's cron)
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml # https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
schedule: schedule:
- cron: "37 5 * * *" - cron: "00 12,18 * * *"
# Allows manual triggering of the workflow # Allows manual triggering of the workflow
workflow_dispatch: workflow_dispatch:
inputs:
dryrun:
description: "Run cleanup step in dry-run mode (log what would be deleted, delete nothing)"
type: boolean
default: false
# Run on workflow file changes (without pushing) # Run on workflow file changes (without pushing)
push: push:
@@ -33,7 +38,7 @@ jobs:
fail-fast: false fail-fast: false
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Free up disk space - name: Free up disk space
if: matrix.platform == 'rocm' if: matrix.platform == 'rocm'
@@ -48,8 +53,18 @@ jobs:
echo "After cleanup:" echo "After cleanup:"
df -h df -h
# QEMU enables arm64 cross-builds on the amd64 GitHub runner.
# Currently only the cpu backend goes multi-arch; the action is a
# no-op for amd64-only builds, so leaving it on for every matrix
# entry keeps the workflow simple.
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
- name: Log in to GitHub Container Registry - name: Log in to GitHub Container Registry
uses: docker/login-action@v2 uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 #v4.1.0
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.actor }} username: ${{ github.actor }}
@@ -60,14 +75,23 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }} run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package # actions/delete-package-versions can't see manifest lists: pushing
# see: https://github.com/actions/delete-package-versions/issues/74 # a multi-arch image with `docker buildx --push` creates a tagged OCI
# index plus one untagged per-platform manifest per arch, and
# `delete-only-untagged-versions: true` then nukes the per-platform
# children, leaving the index dangling — `docker pull :cpu` 404s on
# the referenced digest. dataaxiom/ghcr-cleanup-action walks tagged
# manifest lists and excludes their children from deletion.
delete-untagged-containers: delete-untagged-containers:
needs: build-and-push needs: build-and-push
# Skip on forks — the delete API requires package-admin on the
# upstream account and would otherwise red-x every fork CI run.
if: github.repository == 'mostlygeek/llama-swap'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/delete-package-versions@v5 - uses: dataaxiom/ghcr-cleanup-action@cd0cdb900b5dbf3a6f2cc869f0dbb0b8211f50c4 # v1.0.16
with: with:
package-name: 'llama-swap' token: ${{ secrets.GITHUB_TOKEN }}
package-type: 'container' package: llama-swap
delete-only-untagged-versions: 'true' delete-untagged: true
dry-run: ${{ inputs.dryrun || false }}
+6 -6
View File
@@ -31,17 +31,17 @@ jobs:
run-tests: run-tests:
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
with: with:
go-version: '1.23' go-version-file: go.mod
# cache simple-responder to save the build time # cache simple-responder to save the build time
- name: Restore Simple Responder - name: Restore Simple Responder
id: restore-simple-responder id: restore-simple-responder
uses: actions/cache/restore@v4 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae #v5.0.5
with: with:
path: ./build path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }} key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
@@ -56,11 +56,11 @@ jobs:
# nothing new to save ... skip this step # nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true' if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder id: save-simple-responder
uses: actions/cache/save@v4 uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae #v5.0.5
with: with:
path: ./build path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }} key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
- name: Test all - name: Test all
shell: bash shell: bash
run: make test-all run: make test-all
+7 -6
View File
@@ -30,37 +30,38 @@ jobs:
run-tests: run-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
with: with:
go-version-file: go.mod go-version-file: go.mod
# Only run in this linux based runner # Only run in this linux based runner
- name: Check Formatting - name: Check Formatting
run: | run: |
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then if [ "$(gofmt -l . | wc -l)" -gt 0 ]; then
gofmt -l . | grep -v 'event/.*_test.go' gofmt -l .
exit 1 exit 1
fi fi
# cache simple-responder to save the build time # cache simple-responder to save the build time
- name: Restore Simple Responder - name: Restore Simple Responder
id: restore-simple-responder id: restore-simple-responder
uses: actions/cache/restore@v4 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae #v5.0.5
with: with:
path: ./build path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }} key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping # necessary for testing proxy/Process swapping
- name: Create simple-responder - name: Create simple-responder
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
run: make simple-responder run: make simple-responder
- name: Save Simple Responder - name: Save Simple Responder
# nothing new to save ... skip this step # nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true' if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder id: save-simple-responder
uses: actions/cache/save@v4 uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae #v5.0.5
with: with:
path: ./build path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }} key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
+9 -9
View File
@@ -20,24 +20,24 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
with: with:
fetch-depth: 0 fetch-depth: 0
ref: ${{ github.event.inputs.tag || github.ref }} ref: ${{ github.event.inputs.tag || github.ref }}
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c #6.4.0
with:
go-version-file: go.mod
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # 6.4.0
with: with:
node-version: "24" node-version: "24"
- name: Install dependencies and build UI - name: Build UI
run: | run: |
cd ui-svelte make ui
npm ci
npm run build
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6 uses: goreleaser/goreleaser-action@1a80836c5c9d9e5755a25cb59ec6f45a3b5f41a8 #7.2.1
with: with:
# either 'goreleaser' (default) or 'goreleaser-pro' # either 'goreleaser' (default) or 'goreleaser-pro'
distribution: goreleaser distribution: goreleaser
@@ -61,7 +61,7 @@ jobs:
fi fi
- name: "Trigger tap repository update" - name: "Trigger tap repository update"
uses: peter-evans/repository-dispatch@v2 uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 #4.0.1
with: with:
token: ${{ secrets.TAP_REPO_PAT }} token: ${{ secrets.TAP_REPO_PAT }}
repository: mostlygeek/homebrew-llama-swap repository: mostlygeek/homebrew-llama-swap
+4 -13
View File
@@ -19,24 +19,15 @@ jobs:
run-tests: run-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults:
run:
working-directory: ui-svelte
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # 6.4.0
with: with:
node-version: '24' node-version: '24'
cache: 'npm' cache: 'npm'
cache-dependency-path: ui-svelte/package-lock.json cache-dependency-path: ui-svelte/package-lock.json
- name: Install dependencies - name: Run UI tests
run: npm ci run: make test-ui
- name: Type check
run: npm run check
- name: Run tests
run: npm test
+9 -4
View File
@@ -36,6 +36,11 @@ on:
type: boolean type: boolean
required: false required: false
default: true default: true
push_to_ghcr:
description: "Push images to ghcr.io"
type: boolean
required: false
default: true
permissions: permissions:
contents: read contents: read
@@ -70,7 +75,7 @@ jobs:
backend: ${{ fromJSON(needs.setup.outputs.matrix) }} backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 6.0.2
- name: Free up disk space - name: Free up disk space
run: | run: |
@@ -89,11 +94,11 @@ jobs:
# llama-swap-builder (which has ccache warm) to avoid exhausting disk. # llama-swap-builder (which has ccache warm) to avoid exhausting disk.
- name: Set up Docker Buildx - name: Set up Docker Buildx
if: ${{ !env.ACT }} if: ${{ !env.ACT }}
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
- name: Log in to GitHub Container Registry - name: Log in to GitHub Container Registry
if: ${{ !env.ACT }} if: ${{ !env.ACT }}
uses: docker/login-action@v3 uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 #v4.1.0
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.actor }} username: ${{ github.actor }}
@@ -116,7 +121,7 @@ jobs:
docker/unified/build-image.sh --${{ matrix.backend }} docker/unified/build-image.sh --${{ matrix.backend }}
- name: Push to GitHub Container Registry - name: Push to GitHub Container Registry
if: ${{ !env.ACT }} if: ${{ !env.ACT && (github.event_name == 'schedule' || inputs.push_to_ghcr == true) }}
run: | run: |
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}" BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
DATE_TAG=$(date -u +%Y-%m-%d) DATE_TAG=$(date -u +%Y-%m-%d)
+3
View File
@@ -5,3 +5,6 @@ dist/
.vscode .vscode
.DS_Store .DS_Store
.dev/ .dev/
# UI build output; placeholder.txt is kept so the go:embed succeeds.
internal/server/ui_dist/*
+3 -1
View File
@@ -21,9 +21,11 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc. - Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written. - Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`. - Run `gofmt -w <file>` before committing to fix any formatting
- Build go binaries into the ./build/ subdirectory
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory - Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests. - Use `make test-all` before completing work. This includes long running concurrency tests.
- Use `make test-ui` after making changes to the UI in ui-svelte/
### Commit message example format: ### Commit message example format:
+22 -16
View File
@@ -19,21 +19,17 @@ all: mac linux simple-responder
clean: clean:
rm -rf $(BUILD_DIR) rm -rf $(BUILD_DIR)
proxy/ui_dist/placeholder.txt:
mkdir -p proxy/ui_dist
touch $@
# use cached test results while developing # use cached test results while developing
test-dev: proxy/ui_dist/placeholder.txt test-dev:
go test -short ./proxy/... go test -short ./...
staticcheck ./proxy/... || true staticcheck ./... || true
test: proxy/ui_dist/placeholder.txt test:
go test -short -count=1 ./proxy/... go test -short -count=1 ./internal/...
# for CI - full test (takes longer) # for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt test-all:
go test -race -count=1 ./proxy/... go test -race -count=1 ./internal/...
ui/node_modules: ui/node_modules:
cd ui-svelte && npm install cd ui-svelte && npm install
@@ -41,6 +37,7 @@ ui/node_modules:
# build react UI # build react UI
ui: ui/node_modules ui: ui/node_modules
cd ui-svelte && npm run build cd ui-svelte && npm run build
touch internal/server/ui_dist/placeholder.txt
# Build OSX binary # Build OSX binary
mac: ui mac: ui
@@ -48,17 +45,22 @@ mac: ui
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64 GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
# Build Linux binary # Build Linux binary
linux: ui linux: linux-arm64 linux-amd64
@echo "Building Linux binary..."
linux-amd64: ui
@echo "Building Linux AMD64 binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64 GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
#GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
linux-arm64: ui
@echo "Building Linux ARM64 binary..."
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
# Build Windows binary # Build Windows binary
windows: ui windows: ui
@echo "Building Windows binary..." @echo "Building Windows binary..."
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
# for testing proxy.Process # for testing with real external processes
simple-responder: simple-responder:
@echo "Building simple responder" @echo "Building simple responder"
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
@@ -92,5 +94,9 @@ wol-proxy: $(BUILD_DIR)
@echo "Building wol-proxy" @echo "Building wol-proxy"
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
test-ui:
cd ui-svelte && npm ci && npm run check && npm test
# Phony targets # Phony targets
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy .PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev test-ui wol-proxy
.PHONE: linux linux-arm64 linux-amd64
+48 -18
View File
@@ -20,6 +20,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- `v1/chat/completions` - `v1/chat/completions`
- `v1/responses` - `v1/responses`
- `v1/embeddings` - `v1/embeddings`
- `v1/models` - list available models
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867)) - `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- `v1/audio/voices` - `v1/audio/voices`
@@ -39,16 +40,26 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- ✅ llama-swap API - ✅ llama-swap API
- `/ui` - web UI - `/ui` - web UI
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31)) - `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/log` - remote log monitoring - `POST /api/models/unload` - manually unload all running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `POST /api/models/unload/:model_id` - unload a specific model
- `/logs` - remote log monitoring
- `GET /logs` returns buffered plain text logs.
- If `Accept: text/html` is sent, `/logs` redirects to `/ui/`.
- `GET /logs/stream` keeps the connection open for live log streaming.
- Stream endpoints send buffered history first by default; add `?no-history` to stream only new lines.
- `GET /logs/stream/proxy` streams proxy logs only.
- `GET /logs/stream/upstream` streams upstream process logs only.
- `GET /logs/stream/{model_id}` streams logs for one model (including IDs with slashes, like `author/model`).
- `/health` - just returns "OK" - `/health` - just returns "OK"
- `/metrics` - system and GPU metrics for prometheus
- ✅ API Key support - define keys to restrict access to API endpoints - ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable - ✅ Customizable
- Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643)) - Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643))
- Automatic unloading of models after timeout by setting a `ttl` - Automatic unloading of models after timeout by setting a `ttl`
- Reliable Docker and Podman support using `cmd` and `cmdStop` together - Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235)) - Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
- Apply filters to requests to control inference with `stripParams`, `setParams` and `setParamsByID`
### Web UI ### Web UI
@@ -77,15 +88,32 @@ Real time log streaming:
llama-swap can be installed in multiple ways llama-swap can be installed in multiple ways
1. Docker 1. Docker
2. Homebrew (OSX and Linux) 2. Homebrew (macOS and Linux)
3. WinGet 3. MacPorts (macOS)
4. From release binaries 4. WinGet
5. From source 5. From release binaries
6. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap)) ### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md). Two types of container images are built nightly for llama-swap:
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
1. A unified container with llama-server, ik-llama-server, stable-diffusion.cpp, whisper.cpp and llama-swap built from source. This is only available for cuda and vulkan but has more capabilities. This one is recommended for use.
2. A legacy image that is based on llama.cpp's images and llama-swap copied into the container. Use this one if you prefer to stay close to llama.cpp's container images.
#### Unified container (Recommended)
```shell
$ docker pull ghcr.io/mostlygeek/llama-swap:unified-cuda
# run with a custom configuration and models directory
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/etc/llama-swap/config/config.yaml \
ghcr.io/mostlygeek/llama-swap:unified-cuda
```
#### Legacy container
```shell ```shell
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda $ docker pull ghcr.io/mostlygeek/llama-swap:cuda
@@ -95,14 +123,6 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \ -v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \ -v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda ghcr.io/mostlygeek/llama-swap:cuda
# configuration hot reload supported with a
# directory volume mount
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
-v /path/to/config:/config \
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
``` ```
<details> <details>
@@ -136,6 +156,16 @@ brew install llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080 llama-swap --config path/to/config.yaml --listen localhost:8080
``` ```
### MacPorts (macOS)
> [!NOTE]
> Maintained by MacPorts community - [llama-swap port](https://ports.macports.org/port/llama-swap). It is not an official part of llama-swap.
```shell
sudo port install llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
### WinGet Install (Windows) ### WinGet Install (Windows)
> [!NOTE] > [!NOTE]
@@ -258,6 +288,6 @@ For Python based inference servers like vllm or tabbyAPI it is recommended to ru
## Star History ## Star History
> [!NOTE] > [!NOTE]
> ⭐️ Star this project to help others discover it! > Thank you to everyone who has given this project a ⭐️!
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date) [![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
+183
View File
@@ -0,0 +1,183 @@
# Improve Testability (#655)
## Current Pain Points
1. **Tests bypass config loading** - ~80% of tests build `config.Config` structs directly, skipping YAML parsing, env var substitution, macro expansion, and `${PORT}` assignment. Config bugs in those paths go untested.
2. **simple-responder is everywhere** - Every proxy/routing test launches a real subprocess, waits for health checks (~healthCheckTimeout: 15), and manages process lifecycle just to test HTTP routing. Most of that overhead is wasted.
3. **Port counter is fragile** - A global `nextTestPort` counter starting at 12000 with a mutex. Parallel tests or leftover processes can collide.
## Stages
### Stage 1: YAML-based test config helper
**Goal:** Tests go through the real `LoadConfigFromReader` path instead of hand-building structs.
**Effort:** Low | **Impact:** Config bugs caught earlier | **Risk:** None
Create a test helper in `proxy/helpers_test.go`:
```go
// testConfigFromYAML substitutes simple-responder paths and loads through
// the real config pipeline (env vars, macros, port assignment, etc.)
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
t.Helper()
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
require.NoError(t, err)
return cfg
}
```
Tests would then look like:
```go
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
model2:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model2
`)
proxy := New(config)
// ... same assertions
}
```
**Why this stage first:** Zero production code changes. Pure test-side refactoring. Can be done incrementally - migrate tests one at a time. Each migrated test now validates the full config pipeline.
**Scope:** ~20-30 tests in `proxymanager_test.go`, `processgroup_test.go`, `peerproxy_test.go`.
### Stage 2: Injected test handler (eliminate simple-responder for routing tests)
**Goal:** Replace simple-responder subprocess launches with an injected `http.Handler` for tests that don't specifically test process lifecycle.
**Effort:** Medium | **Impact:** 10-100x faster routing tests | **Risk:** Low (additive, no existing code broken)
Add a `testHandler http.Handler` field to `Process`. When set, `ProxyRequest` delegates directly to this handler instead of going through the reverse proxy. No subprocess, no health checks, no TCP roundtrip.
**2a. Add testHandler to Process:**
```go
// In Process struct (process.go):
testHandler http.Handler // set only in tests; bypasses subprocess and reverse proxy
```
In `Process.Start()`, skip subprocess + health check when handler is set:
```go
func (p *Process) start() error {
if p.testHandler != nil {
p.setState(StateReady)
return nil
}
// existing subprocess logic...
}
```
In `Process.ProxyRequest()`, delegate directly to the handler:
```go
// Before the reverseProxy.ServeHTTP call:
if p.testHandler != nil {
p.testHandler.ServeHTTP(w, r)
return
}
```
**2b. Test helper to create the handler:**
```go
// newTestHandler returns an http.Handler that mimics llama.cpp's API
// (same endpoints as simple-responder).
func newTestHandler(respond string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { ... })
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { ... })
// ... other endpoints
return mux
}
```
Tests for routing/auth/CORS/streaming then become:
```go
func TestProxyManager_AuthRequired(t *testing.T) {
handler := newTestHandler("model1")
config := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
requiredAPIKeys: [test-key]
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
`)
pm := NewProxyManager(config)
// inject handler — skips subprocess, health check, port allocation
pm.processGroups["model1"].process.testHandler = handler
}
```
**Why this matters:** The handler is called directly in-process. No subprocess spawn, no health check timeout, no port allocation, no TCP roundtrip, no reverse proxy overhead. Routing tests go from ~100ms each (process startup + health check) to ~1ms. Unlike an `httptest.Server` approach, there are zero network hops.
**Why not blank-cmd + proxy URL:** A blank `cmd` with a `proxy` field pointing at `httptest.Server` still requires a real TCP roundtrip through the reverse proxy and introduces "external process" semantics to the config schema. Injecting the handler directly keeps it purely a test concern with no config changes.
**Scope:** Most tests in `proxymanager_test.go` (auth, CORS, model listing, streaming, peer proxy), `peerproxy_test.go`, `metrics_monitor_test.go`.
### Stage 3: Migrate tests incrementally
**Goal:** Convert existing tests to use the Stage 1 + Stage 2 helpers.
**Effort:** Medium | **Impact:** Cleaner, more reliable tests | **Risk:** None
Priority order:
1. `proxymanager_test.go` routing tests (highest count, most repetition)
2. `peerproxy_test.go` (straightforward, all HTTP routing)
3. `metrics_monitor_test.go` (capture logic doesn't need real processes)
4. `processgroup_test.go` swap tests (keep simple-responder for actual swap lifecycle tests)
Tests that **must keep simple-responder:**
- Process lifecycle: start/stop, SIGKILL, SIGTERM, TTL expiry, health check failures, failed start counting
- ProcessGroup swap concurrency (the port-collision test in `TestProcessGroup_ProxyRequestSwapIsTrueParallel`)
**Scope:** ~60-70% of tests can drop simple-responder.
### Stage 4 (optional): Process interface for ProcessGroup
**Goal:** Enable pure unit tests of ProcessGroup's swap/exclusive/concurrency logic without any HTTP server at all.
**Effort:** High | **Impact:** Pure unit tests possible | **Risk:** Medium (refactor core code)
```go
type ProcessController interface {
Start() error
Stop(StopStrategy)
ProxyRequest(http.ResponseWriter, *http.Request) error
CurrentState() ProcessState
ID() string
SetState(ProcessState) // for test setup
}
```
This requires:
- Extracting the interface
- A `MockProcess` implementation
- Refactoring `ProcessGroup` to use the interface instead of `*Process`
**Recommendation:** Only do this if ProcessGroup grows significantly more complex. Stages 1-3 give 80% of the benefit for 20% of the effort.
## Effort/Impact Summary
| Stage | Effort | Impact | Risk |
|-------|--------|--------|------|
| 1. YAML config helper | Low | Config bugs caught earlier | None |
| 2. Injected test handler | Medium | 10-100x faster routing tests | Low |
| 3. Migrate tests | Medium | Cleaner, more reliable tests | None |
| 4. Process interface | High | Pure unit tests possible | Medium |
**Recommended approach:** Do stages 1-3 in order. Each stage is independently valuable and can ship on its own. Stage 4 is deferred unless there's a specific need.
+306
View File
@@ -0,0 +1,306 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/tidwall/gjson"
)
var loremWords = strings.Fields(
"Lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor " +
"incididunt ut labore et dolore magna aliqua Ut enim ad minim veniam quis nostrud " +
"exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat Duis aute " +
"irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla " +
"pariatur Excepteur sint occaecat cupidatat non proident sunt in culpa qui officia " +
"deserunt mollit anim id est laborum Sed ut perspiciatis unde omnis iste natus error " +
"sit voluptatem accusantium doloremque laudantium totam rem aperiam eaque ipsa quae " +
"ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo " +
"Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit",
)
var (
flagListen = flag.String("listen", "localhost:9898", "listen address")
flagTokens = flag.Int("tokens", 1000, "number of tokens to return")
flagTPS = flag.Float64("tps", 75, "tokens per second")
flagLoad = flag.String("load", "0s", "simulated load duration (e.g. 2s, 500ms)")
)
type chunkDelta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
}
type chunkChoice struct {
Index int `json:"index"`
Delta chunkDelta `json:"delta"`
FinishReason *string `json:"finish_reason"`
}
type chatChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []chunkChoice `json:"choices"`
}
type completionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type completionChoice struct {
Index int `json:"index"`
Message completionMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type completionUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type chatCompletion struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []completionChoice `json:"choices"`
Usage completionUsage `json:"usage"`
}
func loremText(n int) string {
words := make([]string, n)
for i := range words {
words[i] = loremWords[i%len(loremWords)]
}
return strings.Join(words, " ")
}
func sendChunk(w http.ResponseWriter, content string, finishReason *string) error {
chunk := chatChunk{
ID: "chatcmpl-fake",
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: "fake-model",
Choices: []chunkChoice{
{
Index: 0,
Delta: chunkDelta{Content: content},
FinishReason: finishReason,
},
},
}
data, err := json.Marshal(chunk)
if err != nil {
return err
}
_, err = fmt.Fprintf(w, "data: %s\n\n", data)
return err
}
// startLoading runs the countdown log and closes ready when loadDur elapses.
// If loadDur is zero, ready is closed immediately.
func startLoading(loadDur time.Duration) <-chan struct{} {
ready := make(chan struct{})
if loadDur == 0 {
close(ready)
return ready
}
go func() {
deadline := time.Now().Add(loadDur)
log.Printf("loading... %s remaining", loadDur.Round(time.Second))
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
timer := time.NewTimer(loadDur)
for {
select {
case <-timer.C:
close(ready)
log.Printf("ready")
return
case <-ticker.C:
if rem := time.Until(deadline).Round(time.Second); rem > 0 {
log.Printf("loading... %s remaining", rem)
}
}
}
}()
return ready
}
func healthHandler(ready <-chan struct{}) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
select {
case <-ready:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusServiceUnavailable)
}
}
}
func chatHandler(ready <-chan struct{}) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "failed to read body", http.StatusBadRequest)
return
}
streaming := gjson.GetBytes(body, "stream").Bool()
ctx := r.Context()
select {
case <-ready:
case <-ctx.Done():
return
}
tokens := *flagTokens
tps := *flagTPS
if tps <= 0 {
tps = 1
}
if !streaming {
delay := time.Duration(float64(tokens) / tps * float64(time.Second))
select {
case <-time.After(delay):
case <-ctx.Done():
return
}
text := loremText(tokens)
resp := chatCompletion{
ID: "chatcmpl-fake",
Object: "chat.completion",
Created: time.Now().Unix(),
Model: "fake-model",
Choices: []completionChoice{
{
Index: 0,
Message: completionMessage{Role: "assistant", Content: text},
FinishReason: "stop",
},
},
Usage: completionUsage{
PromptTokens: 0,
CompletionTokens: tokens,
TotalTokens: tokens,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
// Send role delta first
first := chatChunk{
ID: "chatcmpl-fake",
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: "fake-model",
Choices: []chunkChoice{
{Index: 0, Delta: chunkDelta{Role: "assistant"}},
},
}
if data, err := json.Marshal(first); err == nil {
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
interval := time.Duration(float64(time.Second) / tps)
ticker := time.NewTicker(interval)
defer ticker.Stop()
stop := "stop"
for i := 0; i < tokens; i++ {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
word := loremWords[i%len(loremWords)]
if i < tokens-1 {
if err := sendChunk(w, word+" ", nil); err != nil {
return
}
} else {
if err := sendChunk(w, word, &stop); err != nil {
return
}
}
flusher.Flush()
}
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
}
func main() {
flag.Parse()
loadDur, err := time.ParseDuration(*flagLoad)
if err != nil {
log.Fatalf("invalid -load value %q: %v", *flagLoad, err)
}
ready := startLoading(loadDur)
mux := http.NewServeMux()
mux.HandleFunc("/health", healthHandler(ready))
mux.HandleFunc("/v1/chat/completions", chatHandler(ready))
srv := &http.Server{
Addr: *flagListen,
Handler: mux,
}
go func() {
log.Printf("listening on %s (tokens=%d tps=%.1f load=%s)",
*flagListen, *flagTokens, *flagTPS, loadDur)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("server error: %v", err)
}
}()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("shutting down...")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Printf("shutdown error: %v", err)
}
}
+92
View File
@@ -0,0 +1,92 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/perf"
)
func printSysStat(s perf.SysStat) {
cores := make([]string, len(s.CpuUtilPerCore))
for i, v := range s.CpuUtilPerCore {
cores[i] = fmt.Sprintf("%.1f%%", v)
}
fmt.Printf("[SYS %s]\n", s.Timestamp.Format("15:04:05"))
fmt.Printf(" CPU: %s\n", strings.Join(cores, " "))
fmt.Printf(" Mem: %d MB used / %d MB total (%d MB free)\n", s.MemUsedMB, s.MemTotalMB, s.MemFreeMB)
fmt.Printf(" Swap: %d MB used / %d MB total\n", s.SwapUsedMB, s.SwapTotalMB)
fmt.Printf(" Load: %.2f %.2f %.2f (1m 5m 15m)\n", s.LoadAvg1, s.LoadAvg5, s.LoadAvg15)
}
func printGpuStats(gpus []perf.GpuStat) {
for _, g := range gpus {
fmt.Printf("[GPU %d %s]\n", g.ID, g.Name)
fmt.Printf(" Util: GPU %.1f%% Mem %.1f%%\n", g.GpuUtilPct, g.MemUtilPct)
fmt.Printf(" Mem: %d MB used / %d MB total\n", g.MemUsedMB, g.MemTotalMB)
fmt.Printf(" Temp: %d°C Fan: %.1f%% Power: %.1f W\n", g.TempC, g.FanSpeedPct, g.PowerDrawW)
}
}
func main() {
stream := flag.Bool("stream", false, "stream stats")
interval := flag.Duration("t", time.Second, "polling interval (clamped to 1s1h)")
flag.Parse()
every := *interval
if every < time.Second {
every = time.Second
} else if every > time.Hour {
every = time.Hour
}
l := logmon.New()
l.SetLogLevel(logmon.LevelDebug)
s, err := perf.ReadSysStats()
if err != nil && err != perf.ErrNotImplemented {
fmt.Println("Sys Error:", err)
return
}
printSysStat(s)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
gpuCh, err := perf.GetGpuStats(ctx, every, l)
if err != nil && !errors.Is(err, perf.ErrNotImplemented) && !errors.Is(err, perf.ErrNoGpuTool) {
fmt.Println("GPU Init Error:", err)
return
}
if gpuCh != nil {
select {
case g := <-gpuCh:
printGpuStats(g)
case <-ctx.Done():
fmt.Println("GPU: timed out waiting for stats")
}
}
if *stream {
m, _ := perf.New(config.PerformanceConfig{Every: every}, l)
m.Start()
defer m.Stop()
sysCh, gpuCh, unsub := m.Subscribe()
defer unsub()
for {
select {
case s := <-sysCh:
printSysStat(s)
case g := <-gpuCh:
printGpuStats(g)
}
}
}
}
+96
View File
@@ -0,0 +1,96 @@
package main
import (
"flag"
"fmt"
"os"
"sync"
"time"
tea "github.com/charmbracelet/bubbletea"
)
func main() {
prompt := flag.String("prompt", "Write a few sentences about the history of computing.", "user message sent to each model")
maxTokens := flag.Int("max-tokens", 256, "max_tokens per request")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <base-url> <model> [model...]\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Example: %s -max-tokens 400 http://localhost:8080 A B C D\n\n", os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
args := flag.Args()
if len(args) < 2 {
flag.Usage()
os.Exit(1)
}
baseURL := args[0]
models := args[1:]
m := newModel(models)
prog := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
// Chain of triggers ensures requests are sent in the order provided.
triggers := make([]chan struct{}, len(models))
for i := range triggers {
triggers[i] = make(chan struct{}, 1)
}
triggers[0] <- struct{}{}
var wg sync.WaitGroup
start := time.Now()
for i, name := range models {
wg.Add(1)
go func(idx int, mdl string) {
defer wg.Done()
<-triggers[idx]
reqStart := time.Now()
prog.Send(statusMsg{idx: idx, status: statusStreaming})
if idx+1 < len(triggers) {
triggers[idx+1] <- struct{}{}
}
err := sendRequest(baseURL, mdl, *prompt, *maxTokens, idx, func(i int, text string) {
prog.Send(deltaMsg{idx: i, text: text})
})
elapsed := time.Since(reqStart)
if err != nil {
prog.Send(statusMsg{idx: idx, status: statusError, elapsed: elapsed, err: err})
} else {
prog.Send(statusMsg{idx: idx, status: statusDone, elapsed: elapsed})
}
}(i, name)
}
if _, err := prog.Run(); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
wg.Wait()
printSummary(m, start)
}
func printSummary(m *model, start time.Time) {
fmt.Println("Summary:")
for _, p := range m.panels {
switch p.status {
case statusError:
fmt.Printf(" [%d] %-20s ERROR elapsed=%s err=%v\n",
p.idx, p.model, p.elapsed.Round(time.Millisecond), p.err)
case statusDone:
fmt.Printf(" [%d] %-20s done elapsed=%s\n",
p.idx, p.model, p.elapsed.Round(time.Millisecond))
default:
fmt.Printf(" [%d] %-20s %s\n", p.idx, p.model, p.status)
}
}
fmt.Printf("all done in %s\n", time.Since(start).Round(time.Millisecond))
}
+88
View File
@@ -0,0 +1,88 @@
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// deltaSink receives streamed text fragments for a given model panel.
type deltaSink func(idx int, text string)
type streamDelta struct {
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content"`
}
type streamChoice struct {
Delta streamDelta `json:"delta"`
}
type streamChunk struct {
Choices []streamChoice `json:"choices"`
}
// sendRequest streams a chat completion and forwards each content/reasoning
// delta to sink. Reasoning and assistant content are emitted into the same
// stream so they render together.
func sendRequest(baseURL, model, prompt string, maxTokens, idx int, sink deltaSink) error {
payload := map[string]any{
"model": model,
"messages": []map[string]string{
{"role": "user", "content": prompt},
},
"max_tokens": maxTokens,
"stream": true,
}
body, err := json.Marshal(payload)
if err != nil {
return err
}
resp, err := http.Post(baseURL+"/v1/chat/completions", "application/json", bytes.NewReader(body))
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(b)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" || data == "[DONE]" {
if data == "[DONE]" {
break
}
continue
}
var chunk streamChunk
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
for _, c := range chunk.Choices {
if c.Delta.ReasoningContent != "" {
sink(idx, c.Delta.ReasoningContent)
}
if c.Delta.Content != "" {
sink(idx, c.Delta.Content)
}
}
}
return scanner.Err()
}
+343
View File
@@ -0,0 +1,343 @@
package main
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type panelStatus int
const (
statusWaiting panelStatus = iota
statusStreaming
statusDone
statusError
)
func (s panelStatus) String() string {
switch s {
case statusStreaming:
return "streaming"
case statusDone:
return "done"
case statusError:
return "error"
default:
return "waiting"
}
}
// deltaMsg appends streamed text to a panel.
type deltaMsg struct {
idx int
text string
}
// statusMsg updates a panel's lifecycle state.
type statusMsg struct {
idx int
status panelStatus
elapsed time.Duration
err error
}
type panel struct {
idx int
model string
color lipgloss.Color
status panelStatus
buf strings.Builder
elapsed time.Duration
err error
}
const (
minPanelWidth = 28
maxCols = 3
panelHeight = 9 // total box height including border + header
)
type model struct {
panels []*panel
focused int
vp viewport.Model
width int
height int
cols int
pw int // inner panel content width
ready bool
}
func newModel(models []string) *model {
// Assign a stable color per unique model name (by first appearance).
colorOf := map[string]lipgloss.Color{}
panels := make([]*panel, len(models))
for i, m := range models {
c, ok := colorOf[m]
if !ok {
c = modelPalette[len(colorOf)%len(modelPalette)]
colorOf[m] = c
}
panels[i] = &panel{idx: i, model: m, color: c, status: statusWaiting}
}
return &model{panels: panels, focused: 0}
}
func (m *model) Init() tea.Cmd { return nil }
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
m.relayout()
m.refreshViewport(true)
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "q", "ctrl+c", "esc":
return m, tea.Quit
case "tab", "right", "l":
m.setFocus(m.focused + 1)
return m, nil
case "shift+tab", "left", "h":
m.setFocus(m.focused - 1)
return m, nil
}
var cmd tea.Cmd
m.vp, cmd = m.vp.Update(msg)
return m, cmd
case tea.MouseMsg:
if msg.Action == tea.MouseActionPress && msg.Button == tea.MouseButtonLeft {
if idx, ok := m.panelAt(msg.X, msg.Y); ok {
m.setFocus(idx)
}
return m, nil
}
var cmd tea.Cmd
m.vp, cmd = m.vp.Update(msg)
return m, cmd
case deltaMsg:
p := m.panels[msg.idx]
p.buf.WriteString(msg.text)
if msg.idx == m.focused {
atBottom := m.vp.AtBottom()
m.refreshViewport(false)
if atBottom {
m.vp.GotoBottom()
}
}
return m, nil
case statusMsg:
p := m.panels[msg.idx]
p.status = msg.status
p.elapsed = msg.elapsed
p.err = msg.err
if msg.err != nil {
errTxt := lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Render("\n" + msg.err.Error())
p.buf.WriteString(errTxt)
if msg.idx == m.focused {
m.refreshViewport(false)
m.vp.GotoBottom()
}
}
return m, nil
}
return m, nil
}
func (m *model) setFocus(idx int) {
if len(m.panels) == 0 {
return
}
if idx < 0 {
idx = len(m.panels) - 1
}
if idx >= len(m.panels) {
idx = 0
}
if idx == m.focused {
return
}
m.focused = idx
m.refreshViewport(true)
}
// relayout recomputes grid columns and panel/viewport dimensions.
func (m *model) relayout() {
if m.width < minPanelWidth+4 {
m.cols = 1
} else {
m.cols = m.width / (minPanelWidth + 2)
if m.cols > maxCols {
m.cols = maxCols
}
if m.cols > len(m.panels) {
m.cols = len(m.panels)
}
if m.cols < 1 {
m.cols = 1
}
}
// inner content width: total width / cols, minus borders+padding (4) and gap.
boxOuter := m.width/m.cols - 1
m.pw = boxOuter - 4
if m.pw < 8 {
m.pw = 8
}
m.vp = viewport.New(m.pw, panelHeight-2)
m.ready = true
}
func (m *model) refreshViewport(reset bool) {
if !m.ready || len(m.panels) == 0 {
return
}
content := lipgloss.NewStyle().Width(m.pw).Render(m.panels[m.focused].buf.String())
m.vp.SetContent(content)
if reset {
m.vp.GotoBottom()
}
}
// panelAt maps screen coordinates to a panel index based on the grid layout.
func (m *model) panelAt(x, y int) (int, bool) {
if m.cols == 0 {
return 0, false
}
boxOuterW := m.width/m.cols + 1
col := x / boxOuterW
row := y / panelHeight
idx := row*m.cols + col
if col < m.cols && idx >= 0 && idx < len(m.panels) {
return idx, true
}
return 0, false
}
func (m *model) View() string {
if !m.ready {
return "loading..."
}
rows := []string{}
var current []string
for i, p := range m.panels {
current = append(current, m.renderPanel(p, i == m.focused))
if len(current) == m.cols {
rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...))
current = nil
}
}
if len(current) > 0 {
rows = append(rows, lipgloss.JoinHorizontal(lipgloss.Top, current...))
}
grid := lipgloss.JoinVertical(lipgloss.Left, rows...)
footer := lipgloss.NewStyle().Faint(true).Render(
"tab/click: focus panel • wheel/↑↓/pgup/pgdn: scroll focused • q: quit")
return grid + "\n" + footer
}
// modelPalette gives each panel a distinct, readable color for its name.
var modelPalette = []lipgloss.Color{
"39", // blue
"213", // magenta
"214", // orange
"45", // cyan
"141", // purple
"203", // salmon
"82", // lime
"227", // light yellow
}
func statusColor(s panelStatus) lipgloss.Color {
switch s {
case statusStreaming:
return lipgloss.Color("220") // yellow - active
case statusDone:
return lipgloss.Color("42") // green - success
case statusError:
return lipgloss.Color("196") // red - error
default:
return lipgloss.Color("244") // gray - waiting
}
}
func (m *model) renderPanel(p *panel, focused bool) string {
border := lipgloss.RoundedBorder()
if focused {
border = lipgloss.DoubleBorder()
}
style := lipgloss.NewStyle().
Border(border).
BorderForeground(lipgloss.Color("240"))
statusTxt := p.status.String()
if p.elapsed > 0 {
statusTxt += " " + p.elapsed.Round(time.Millisecond).String()
}
// Header: model name (left, model color) + status/timer (right, status color).
name := fmt.Sprintf("[%d] %s", p.idx, p.model)
gap := m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt)
if gap < 1 {
name = truncate(name, m.pw-lipgloss.Width(statusTxt)-1)
gap = m.pw - lipgloss.Width(name) - lipgloss.Width(statusTxt)
}
if gap < 1 {
gap = 1
}
header := lipgloss.NewStyle().Bold(true).Foreground(p.color).Render(name) +
strings.Repeat(" ", gap) +
lipgloss.NewStyle().Foreground(statusColor(p.status)).Render(statusTxt)
var bodyLines string
if focused {
bodyLines = m.vp.View()
} else {
bodyLines = tailLines(p.buf.String(), m.pw, panelHeight-2)
}
content := lipgloss.JoinVertical(lipgloss.Left, header, bodyLines)
return style.Width(m.pw).Height(panelHeight - 2).Render(content)
}
func truncate(s string, w int) string {
if w <= 0 {
return ""
}
if lipgloss.Width(s) <= w {
return s
}
r := []rune(s)
if len(r) > w {
r = r[:w]
}
return string(r)
}
// tailLines wraps text to width w and returns the last n lines.
func tailLines(s string, w, n int) string {
wrapped := lipgloss.NewStyle().Width(w).Render(s)
lines := strings.Split(wrapped, "\n")
if len(lines) > n {
lines = lines[len(lines)-n:]
}
for len(lines) < n {
lines = append(lines, "")
}
return strings.Join(lines, "\n")
}
+227 -72
View File
@@ -82,6 +82,78 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"description": "Timeout settings for proxy connections." "description": "Timeout settings for proxy connections."
},
"groupsConfig": {
"type": "object",
"additionalProperties": {
"type": "object",
"required": [
"members"
],
"properties": {
"swap": {
"type": "boolean",
"default": true,
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
},
"exclusive": {
"type": "boolean",
"default": true,
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
},
"persistent": {
"type": "boolean",
"default": false,
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
},
"members": {
"type": "array",
"items": {
"type": "string"
},
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
}
}
},
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
},
"matrixConfig": {
"type": "object",
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
"required": [
"vars",
"sets"
],
"properties": {
"vars": {
"type": "object",
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
"minProperties": 1,
"additionalProperties": {
"type": "string"
},
"propertyNames": {
"pattern": "^[a-zA-Z0-9]{1,8}$"
}
},
"evict_costs": {
"type": "object",
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
"additionalProperties": {
"type": "integer",
"minimum": 1
}
},
"sets": {
"type": "object",
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
"minProperties": 1,
"additionalProperties": {
"type": "string"
}
}
},
"additionalProperties": false
} }
}, },
"properties": { "properties": {
@@ -142,6 +214,25 @@
"default": 5, "default": 5,
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures." "description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
}, },
"performance": {
"type": "object",
"properties": {
"disabled": {
"type": "boolean",
"default": false,
"description": "Disable system performance monitoring."
},
"every": {
"type": "string",
"pattern": "^[-+]?(\\d+(\\.\\d+)?(ns|us|ms|s|m|h))+$",
"default": "15s",
"description": "Delay between polling for new performance statistics. Minimum duration is 1s. Lower values use more RAM as stats are kept in memory."
}
},
"additionalProperties": false,
"default": {},
"description": "Configuration for CPU, RAM and GPU monitoring statistics."
},
"startPort": { "startPort": {
"type": "integer", "type": "integer",
"default": 5800, "default": 5800,
@@ -287,81 +378,68 @@
}, },
"timeouts": { "timeouts": {
"$ref": "#/definitions/timeouts" "$ref": "#/definitions/timeouts"
},
"capabilities": {
"type": "object",
"properties": {
"in": {
"type": "array",
"minItems": 1,
"uniqueItems": true,
"default": [],
"items": {
"type": "string",
"enum": [
"text",
"audio",
"image"
]
},
"description": "List of input modalities understood by the model."
},
"out": {
"type": "array",
"minItems": 1,
"uniqueItems": true,
"default": [],
"items": {
"type": "string",
"enum": [
"text",
"audio",
"image"
]
},
"description": "List of output modalities generated by the model."
},
"tools": {
"type": "boolean",
"default": false,
"description": "Whether the model supports function calling."
},
"reranker": {
"type": "boolean",
"default": false,
"description": "Whether the model supports the /v1/rerank endpoint."
},
"context": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Maximum token context length supported by the model."
}
},
"additionalProperties": false,
"description": "Defines what the model accepts for input, output and other metadata. Used in v1/models to inform clients what the model can do. An empty capabilities block (all zero values) is treated as not configured."
} }
} }
} }
}, },
"groups": { "groups": {
"type": "object", "$ref": "#/definitions/groupsConfig"
"additionalProperties": {
"type": "object",
"required": [
"members"
],
"properties": {
"swap": {
"type": "boolean",
"default": true,
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
},
"exclusive": {
"type": "boolean",
"default": true,
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
},
"persistent": {
"type": "boolean",
"default": false,
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
},
"members": {
"type": "array",
"items": {
"type": "string"
},
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
}
}
},
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
}, },
"matrix": { "matrix": {
"type": "object", "$ref": "#/definitions/matrixConfig"
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
"required": [
"vars",
"sets"
],
"properties": {
"vars": {
"type": "object",
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
"minProperties": 1,
"additionalProperties": {
"type": "string"
},
"propertyNames": {
"pattern": "^[a-zA-Z0-9]{1,8}$"
}
},
"evict_costs": {
"type": "object",
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
"additionalProperties": {
"type": "integer",
"minimum": 1
}
},
"sets": {
"type": "object",
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
"minProperties": 1,
"additionalProperties": {
"type": "string"
}
}
},
"additionalProperties": false
}, },
"hooks": { "hooks": {
"type": "object", "type": "object",
@@ -493,26 +571,103 @@
}, },
"default": {}, "default": {},
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap." "description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
},
"routing": {
"type": "object",
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
"properties": {
"scheduler": {
"type": "object",
"description": "Scheduler configuration. Decides the order in which queued requests are serviced.",
"properties": {
"use": {
"type": "string",
"enum": [
"fifo"
],
"default": "fifo",
"description": "Scheduler to use. Only 'fifo' is currently supported."
},
"settings": {
"type": "object",
"properties": {
"fifo": {
"type": "object",
"properties": {
"priority": {
"type": "object",
"description": "Per-model priority. Keys are model IDs, values are integers (default 0). Higher values are serviced first.",
"additionalProperties": {
"type": "integer"
}
}
},
"additionalProperties": false
}
},
"additionalProperties": false
}
},
"additionalProperties": false
},
"router": {
"type": "object",
"description": "Router configuration. Selects between the group and matrix swapping strategies.",
"properties": {
"use": {
"type": "string",
"enum": [
"group",
"matrix"
],
"default": "group",
"description": "Router to use. 'group' uses static groups, 'matrix' uses the solver-based swap matrix."
},
"settings": {
"type": "object",
"properties": {
"groups": {
"$ref": "#/definitions/groupsConfig"
},
"matrix": {
"$ref": "#/definitions/matrixConfig"
}
},
"additionalProperties": false
}
},
"additionalProperties": false
}
},
"additionalProperties": false
} }
}, },
"allOf": [ "allOf": [
{ {
"if": { "if": {
"required": ["groups"] "required": [
"groups"
]
}, },
"then": { "then": {
"not": { "not": {
"required": ["matrix"] "required": [
"matrix"
]
} }
} }
}, },
{ {
"if": { "if": {
"required": ["matrix"] "required": [
"matrix"
]
}, },
"then": { "then": {
"not": { "not": {
"required": ["groups"] "required": [
"groups"
]
} }
} }
} }
+216 -81
View File
@@ -55,6 +55,18 @@ metricsMaxInMemory: 1000
# - set to 0 to disable # - set to 0 to disable
captureBuffer: 15 captureBuffer: 15
# performance: configuration for system monitoring statistics
# - timing values are duration strings like 1s, 1h30m, 90m, 2h10s, etc.
performance:
# disabled: boolean
# - default: false
disabled: false
# every: delay between polling for new performance statistics
# - default: 5s
# - minimum duration 5s
every: 15s
# startPort: sets the starting port number for the automatic ${PORT} macro. # startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800 # - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings # - the ${PORT} macro can be used in model.cmd and model.proxy settings
@@ -96,8 +108,7 @@ globalTTL: 0
macros: macros:
# Example of a multi-line macro # Example of a multi-line macro
"latest-llama": > "latest-llama": >
/path/to/llama-server/llama-server-ec9e0301 /path/to/llama-server/llama-server-ec9e0301 --port ${PORT}
--port ${PORT}
"default_ctx": 4096 "default_ctx": 4096
@@ -257,7 +268,8 @@ models:
# the ${temp} macro will remain a float # the ${temp} macro will remain a float
temperature: ${temp} temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}" note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp},
context=${default_ctx}"
a_list: a_list:
- 1 - 1
@@ -300,6 +312,37 @@ models:
tlsHandshake: 10 tlsHandshake: 10
idleConn: 90 idleConn: 90
# capabilities: defines what the model accepts for input, output and other metadata
# - optional; omitted or all-zero means no capabilities
# - used in v1/models to inform clients what the model can do
capabilities:
# in: list of modalities understood by the model
# - default: []
# - valid: text, audio, image
in:
- text
- audio
- image
# out: list of modalities generated by the model
# - default: []
# - valid: text, audio, image
out:
- text
- audio
- image
# tools: the model supports function calling
# - default: false
tools: true
# reranker: the model supports the /v1/rerank endpoint
# - default: false
reranker: false
# context: the maximum token context length supported
# - default: 0
# - must be an integer > 0
context: 32000
# Unlisted model example: # Unlisted model example:
"qwen-unlisted": "qwen-unlisted":
# unlisted: boolean, true or false # unlisted: boolean, true or false
@@ -331,84 +374,6 @@ models:
# - processes have 5 seconds to shutdown until forceful termination is attempted # - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID} cmdStop: docker stop ${MODEL_ID}
# =============================================================================
# matrix: run concurrent models with a solver-based swap DSL
# =============================================================================
#
# Note:
# A config must use either a matrix or legacy groups, not both. A configuration error
# will occur if both are defined. Configuration examples for legacy Groups can be found:
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
#
# The matrix declares valid combinations of models that can run concurrently.
# When a model is requested, the solver finds the cheapest way to make it
# available by evicting as few (and least costly) running models as possible.
#
# Solver behavior:
# 1. Request arrives for model X
# 2. If X is already running, forward immediately. Done.
# 3. Find all sets containing X
# 4. For each candidate set, compute cost: sum of evict_costs for
# every running model NOT in that set
# 5. Pick lowest cost candidate. Ties broken by definition order.
# 6. Evict what needs to stop. Start X. Forward request.
#
# Subset semantics: a set [a, b, c] means any subset is valid.
# Only the requested model is started — others are not preloaded.
#
# A model not appearing in any set can only run alone.
#
matrix:
# vars: short names for models (alphanumeric, 1-8 chars)
# - required for sets and evict_costs settings
# - each entry is a short name to a real model ID. Do not use an alias
# - used to keep set DSL logic short and easier to read
# - sets and evict_costs only use identifiers defined in vars
vars:
g: gemma-model
q: qwen-model
m: mistral-model
v: voxtral-model
e: reranker-model
L: llama-70B
sd: stable-diffusion
# evict_costs: relative cost of losing a running model (default: 1)
evict_costs:
v: 50 # vllm backend, slow cold start
L: 30 # 70B weights, slow to load
# sets: named sets of concurrent model combinations
# Values are DSL strings with operators:
# & AND (models run together)
# | OR (alternatives)
# () grouping
# +ref inline another set's expression
#
# Expansion examples:
# "L" → [L]
# "a & b" → [a, b]
# "a | b" → [a], [b]
# "(a | b) & c" → [a, c], [b, c]
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
# "+llms & v" → expands llms inline, then applies & v
sets:
# LLM + TTS: switching between g/q/m won't evict v
# expands to: [g,v], [q,v], [m,v]
standard: "(g | q | m) & v"
# LLM + TTS + reranker
# expands to: [g,v,e], [q,v,e]
with_rerank: "(g | q) & v & e"
# LLM + image generation, no TTS
# expands to: [g,sd], [q,sd]
creative: "(g | q) & sd"
# 70B model uses all GPUs, can only run alone
# expands to: [L]
full: "L"
# hooks: a dictionary of event triggers and actions # hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - the only supported hook is on_startup # - the only supported hook is on_startup
@@ -425,6 +390,176 @@ hooks:
preload: preload:
- "llama" - "llama"
# routing:
# Controls how llama-swap decides which models can run at the same time and
# which get swapped out. Choose one of two swap engines:
#
# - group: the default engine. Simpler to configure. You define groups of
# models that run together, and loading one group typically unloads
# the others.
#
# - matrix: the newer engine. More involved to configure, but far more
# flexible. It uses a small expression language to describe which
# model combinations are allowed to run concurrently, enabling
# setups that groups cannot express.
#
# The routing section is optional.
routing:
router:
# use: a string defining which engine to use
# - optional, default: "group"
# - valid values: group, matrix
use: group
# settings: a dictionary of settings for the specific engines
settings:
# groups: a dictionary of named groups
# - optional, default: empty dictionary
# - lets you keep some models loaded while others swap out
# - every member must be a model ID defined in the models section
# - a model can belong to only one group
# - behaviour is set per group with the `swap`, `exclusive` and
# `persistent` fields
# - see issue #109 for details
#
# NOTE: the model names below are illustrative and are not defined above.
groups:
# group1 reproduces llama-swap's default behaviour: only one model
# runs at a time across the entire instance.
"group1":
# swap: how members of this group swap among themselves
# - optional, default: true
# - true: only one member runs at a time
# - false: all members can run together, no swapping
swap: true
# exclusive: how this group affects other groups
# - optional, default: true
# - true: running a member unloads every other group
# - false: running a member leaves other groups untouched
exclusive: true
# members: the model IDs in this group
# required
members:
- "llama"
- "qwen-unlisted"
# group2: members all run together, but loading any other group
# unloads them.
"group2":
# swap: false lets all members stay loaded at once
swap: false
# exclusive: false means requesting a member loads it without
# unloading any other group
exclusive: false
members:
- "docker-llama"
- "modelA"
- "modelB"
# forever: a persistent group that other groups can never unload.
"forever":
# persistent: other groups cannot unload this group's members
# - optional, default: false
# - has no effect on swapping within the group
persistent: true
# swap/exclusive: false keeps all members loaded and avoids
# unloading other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
# The matrix lists the model combinations that are allowed to run
# concurrently. When a model is requested, the solver makes room for it
# by evicting as few running models as possible, preferring to keep the
# costliest ones loaded.
#
# Solver behaviour:
# 1. A request arrives for model X.
# 2. If X is already running, forward the request. Done.
# 3. Collect every set that contains X.
# 4. For each set, add up the evict_costs of the running models that
# are NOT in that set — that is the set's cost.
# 5. Choose the lowest-cost set. Break ties by definition order.
# 6. Evict the models outside that set, start X, forward the request.
#
# Subset semantics: a set [a, b, c] also permits any subset of itself.
# Only the requested model is started; the others are not preloaded.
#
# A model that appears in no set can only run on its own.
#
matrix:
# vars: short aliases for model IDs (alphanumeric, 1-8 chars)
# - required: sets and evict_costs reference these names, not model IDs
# - map each short name to a real model ID (not a model alias)
# - keeps the set expressions short and readable
vars:
g: gemma-model
q: qwen-model
m: mistral-model
v: voxtral-model
e: reranker-model
L: llama-70B
sd: stable-diffusion
# evict_costs: relative cost of losing a running model (default: 1)
evict_costs:
v: 50 # vllm backend, slow cold start
L: 30 # 70B weights, slow to load
# sets: named combinations of models that may run together.
# Each value is an expression built from these operators:
# & AND (models run together)
# | OR (alternatives)
# () grouping
# +ref inline the expression of another set
#
# Each expression expands into one or more concrete sets:
# "L" → [L]
# "a & b" → [a, b]
# "a | b" → [a], [b]
# "(a | b) & c" → [a, c], [b, c]
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
# "+llms & v" → inline the llms set, then AND with v
sets:
# An LLM plus TTS. Switching between g/q/m keeps v loaded.
# expands to: [g,v], [q,v], [m,v]
standard: "(g | q | m) & v"
# An LLM plus TTS plus reranker.
# expands to: [g,v,e], [q,v,e]
with_rerank: "(g | q) & v & e"
# An LLM plus image generation, no TTS.
# expands to: [g,sd], [q,sd]
creative: "(g | q) & sd"
# The 70B model uses every GPU, so it can only run alone.
# expands to: [L]
full: "L"
# scheduler: how queued requests are ordered.
# The default and only valid scheduler is "fifo"
scheduler:
use: fifo
settings:
fifo:
# priority: a dictionary of model ID -> priority
# - optional, default: empty dictionary
# - models default to priority 0
# - higher priority requests are serviced first in the queue
priority:
A: 10
B: 5
C: 5
D: 1
# peers: a dictionary of remote peers and models they provide # peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary # - optional, default empty dictionary
# - peers can be another llama-swap # - peers can be another llama-swap
+59 -9
View File
@@ -46,13 +46,31 @@ fi
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp} BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp} SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable # LS_REPO is the destination of the built container image — defaults to the
# to enable easy container builds on forked repos # current GitHub repository so forked CI builds publish to the fork's own
# ghcr.io namespace without code changes.
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap} LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
# LS_BINARY_REPO is where the llama-swap release tarball is downloaded
# from. Decoupled from LS_REPO so forks (which usually have no releases of
# their own) can still build a container by pulling the canonical binary
# from upstream. Override via the LS_BINARY_REPO env var when you maintain
# fork-side releases.
LS_BINARY_REPO=${LS_BINARY_REPO:-mostlygeek/llama-swap}
# the most recent llama-swap tag # the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming # have to strip out the 'v' due to .tar.gz file naming.
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//') # Authenticated request — unauth'd github.com API is 60/hr per IP and GHA
# runners share IPs, so the call regularly returns rate-limit JSON and
# `.tag_name` then resolves to "null", producing a bogus `vnull` URL below.
LS_VER=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/repos/${LS_BINARY_REPO}/releases/latest" \
| jq -r .tag_name | sed 's/v//')
if [[ -z "$LS_VER" || "$LS_VER" == "null" ]]; then
log_info "Error: could not resolve latest llama-swap release tag from ${LS_BINARY_REPO}"
exit 1
fi
# Fetches the most recent llama.cpp tag matching the given prefix # Fetches the most recent llama.cpp tag matching the given prefix
# Handles pagination to search beyond the first 100 results # Handles pagination to search beyond the first 100 results
@@ -126,6 +144,25 @@ if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
exit 0 exit 0
fi fi
# cpu is the only backend with a multi-arch upstream base
# (ghcr.io/ggml-org/llama.cpp:server-bXXXX ships amd64+arm64); GPU backends
# are amd64-only and stay on the original `docker build` path so the
# sd-server layer can still FROM the just-built image via the local
# dockerd image store (buildx's container driver has a separate store
# that doesn't share with dockerd, which breaks the sd build).
if [ "$ARCH" == "cpu" ]; then
if [ "$PUSH_IMAGES" == "true" ]; then
BUILDX_FLAGS="--push --platform linux/amd64,linux/arm64"
else
# Smoke build: validate both platforms but emit no output. buildx
# on the docker-container driver defaults to cacheonly when
# neither --push nor --load is given, so each arch fully builds
# and a regression in either fails CI — without materializing the
# image or needing to --load (which is multi-arch-incompatible).
BUILDX_FLAGS="--platform linux/amd64,linux/arm64"
fi
fi
for CONTAINER_TYPE in non-root root; do for CONTAINER_TYPE in non-root root; do
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}" CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}" CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
@@ -142,11 +179,23 @@ for CONTAINER_TYPE in non-root root; do
fi fi
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER" log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \ if [ "$ARCH" == "cpu" ]; then
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \ docker buildx build $BUILDX_FLAGS --provenance=false \
--build-arg BASE_IMAGE=${BASE_IMAGE} . -f llama-swap.Containerfile \
--build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_BINARY_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} \
--build-arg BASE_IMAGE=${BASE_IMAGE} \
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
else
docker build --provenance=false -f llama-swap.Containerfile \
--build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_BINARY_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} \
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
--build-arg BASE_IMAGE=${BASE_IMAGE} .
fi
# For architectures with stable-diffusion.cpp support, layer sd-server on top # For architectures with stable-diffusion.cpp support, layer sd-server on top.
# Stays on `docker build` so the base resolves from local dockerd.
case "$ARCH" in case "$ARCH" in
"musa" | "vulkan") "musa" | "vulkan")
log_info "Adding sd-server to $CONTAINER_TAG" log_info "Adding sd-server to $CONTAINER_TAG"
@@ -157,7 +206,8 @@ for CONTAINER_TYPE in non-root root; do
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;; -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
esac esac
if [ "$PUSH_IMAGES" == "true" ]; then # cpu builds push inline via buildx --push; all other archs push here.
if [ "$ARCH" != "cpu" ] && [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG} docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST} docker push ${CONTAINER_LATEST}
fi fi
+9 -4
View File
@@ -2,7 +2,6 @@ ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
ARG BASE_TAG=server-cuda ARG BASE_TAG=server-cuda
FROM ${BASE_IMAGE}:${BASE_TAG} FROM ${BASE_IMAGE}:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=170 ARG LS_VER=170
ARG LS_REPO=mostlygeek/llama-swap ARG LS_REPO=mostlygeek/llama-swap
@@ -34,9 +33,15 @@ WORKDIR /app
ENV PATH="/app:${PATH}" ENV PATH="/app:${PATH}"
RUN \ RUN \
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \ set -eux; \
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \ case "$(uname -m)" in \
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz" x86_64) ARCH=amd64 ;; \
aarch64) ARCH=arm64 ;; \
*) echo "unsupported arch: $(uname -m)" >&2; exit 1 ;; \
esac; \
curl --fail -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
tar -zxf "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz" && \
rm "llama-swap_${LS_VER}_linux_${ARCH}.tar.gz"
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
+5 -1
View File
@@ -42,6 +42,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git python3 python3-pip libssl-dev \ build-essential cmake git python3 python3-pip libssl-dev \
curl ca-certificates ccache make wget software-properties-common \ curl ca-certificates ccache make wget software-properties-common \
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \ libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
spirv-headers \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
WORKDIR /build WORKDIR /build
@@ -148,7 +149,7 @@ ARG IK_LLAMA_COMMIT_HASH=unknown
ARG RUN_UID=0 ARG RUN_UID=0
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
python3-numpy python3-sentencepiece \ python3-numpy python3-sentencepiece python3-pip \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Create non-root user when RUN_UID != 0 # Create non-root user when RUN_UID != 0
@@ -179,6 +180,9 @@ COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
# Copy ik-llama-server (CUDA only; empty copy for vulkan) # Copy ik-llama-server (CUDA only; empty copy for vulkan)
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/ COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
# Install uv
RUN pip install uv --break-system-packages
# Copy llama-swap binary # Copy llama-swap binary
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/ COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
COPY --from=llama-swap-download /install/llama-swap-version /tmp/ COPY --from=llama-swap-download /install/llama-swap-version /tmp/
+10 -2
View File
@@ -38,8 +38,16 @@ if [ "$VERSION" = "latest" ]; then
echo "Latest version: ${VERSION}" echo "Latest version: ${VERSION}"
fi fi
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
aarch64|arm64) ARCH="arm64" ;;
*) echo "FATAL: Unsupported architecture: $ARCH" >&2; exit 1 ;;
esac
# Download and extract # Download and extract
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_amd64.tar.gz" URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_${ARCH}.tar.gz"
echo "=== Downloading llama-swap v${VERSION} ===" echo "=== Downloading llama-swap v${VERSION} ==="
echo "URL: $URL" echo "URL: $URL"
curl -fSL -o /tmp/llama-swap.tar.gz "$URL" curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
@@ -56,4 +64,4 @@ fi
echo "$VERSION" > /install/llama-swap-version echo "$VERSION" > /install/llama-swap-version
echo "=== llama-swap v${VERSION} installed ===" echo "=== llama-swap v${VERSION} installed ==="
ls -la /install/bin/llama-swap ls -la /install/bin/llama-swap
+15 -3
View File
@@ -146,6 +146,18 @@ metricsMaxInMemory: 1000
# - set to 0 to disable # - set to 0 to disable
captureBuffer: 15 captureBuffer: 15
# performance: configuration for system monitoring statistics
# - timing values are duration strings like 1s, 1h30m, 90m, 2h10s, etc.
performance:
# disabled: boolean
# - default: false
enable: true
# every: delay between polling for new performance statistics
# - default: 5s
# - minimum duration 5s
every: 5s
# startPort: sets the starting port number for the automatic ${PORT} macro. # startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800 # - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings # - the ${PORT} macro can be used in model.cmd and model.proxy settings
@@ -187,8 +199,7 @@ globalTTL: 0
macros: macros:
# Example of a multi-line macro # Example of a multi-line macro
"latest-llama": > "latest-llama": >
/path/to/llama-server/llama-server-ec9e0301 /path/to/llama-server/llama-server-ec9e0301 --port ${PORT}
--port ${PORT}
"default_ctx": 4096 "default_ctx": 4096
@@ -348,7 +359,8 @@ models:
# the ${temp} macro will remain a float # the ${temp} macro will remain a float
temperature: ${temp} temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}" note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp},
context=${default_ctx}"
a_list: a_list:
- 1 - 1
File diff suppressed because it is too large Load Diff
+264
View File
@@ -0,0 +1,264 @@
# New Router Migration TODO
This document tracks the work needed for [cmd/newrouter/main.go](../cmd/newrouter/main.go) and [internal/router/](../internal/router/) to reach feature parity with the legacy entrypoint at [llama-swap.go](../llama-swap.go) plus [proxy/proxymanager.go](../proxy/proxymanager.go).
The work is split into phases so each can land and be tested independently. Earlier phases unblock later ones.
## Current state (newrouter)
`cmd/newrouter` already supports:
- Loading config via `-config`
- Selecting Matrix vs Group router based on config
- Peer routing fallback
- Plain HTTP listen (`-listen`)
- Graceful shutdown on `SIGINT` / `SIGTERM`
- Model extraction from JSON body, query string, and form bodies (see [router.go:88](../internal/router/router.go#L88))
- `Server.ServeHTTP` dispatches a single request to peer or local router based on the requested model
Everything below is missing or only partially implemented.
---
## Phase 1 — Package relocation -- Completed.
Goal: move shared infrastructure packages out from under `proxy/` so the new router does not depend on the legacy proxy tree. This is a prerequisite for retiring `proxy/` in Phase 8.
---
## Phase 2 — Server lifecycle parity -- Completed.
Goal: make `cmd/newrouter` a drop-in replacement for the legacy binary's process model, _without_ yet adding any extra HTTP endpoints.
---
## Phase 3 — `internal/chain` package -- Completed.
API: `chain.New(mws...).Then(final)` for ServeMux registration; `Append` returns an extended Chain without mutating the receiver, so a base stack (auth/CORS) can be reused across many routes with per-route additions.
---
## Phase 4 — `internal/server` package scaffolding (ProxyManager replacement) -- Completed.
Goal: build the [internal/server](../internal/server/) package so it can stand in for [proxy.ProxyManager](../proxy/proxymanager.go#L67) — the mux, lifecycle, model dispatch, custom endpoints, request filters, auth/CORS, and upstream passthrough. After this phase, `cmd/newrouter/main.go` constructs a `server.Server` instead of a bare `router.Server`.
The legacy `ProxyManager` collapses three concerns into one struct: the HTTP mux, the model→process router, and the cross-cutting services (loggers, metrics, perf, inflight counter, version). The new layout keeps the `router.Router` implementations focused on model dispatch and lets `internal/server.Server` own the mux and all cross-cutting middleware. `server.Server` builds the `local` and `peer` routers directly and dispatches between them itself, so it fully **supersedes `internal/router.Server`** — see the cleanup item below.
The phase is split into sub-phases that can land and be tested independently:
| Sub-phase | Scope |
| --------- | -------------------------------------------------------------------------- |
| 4a | package scaffolding — struct, `New`, `ServeHTTP`, `Shutdown`, model routes |
| 4b | custom (non-model-dispatched) HTTP endpoints |
| 4c | request-body filter middleware |
| 4d | auth & CORS middleware |
| 4e | upstream passthrough |
The package is split by concern across stub files already in place:
| File | Responsibility | Filled in by |
| ------------ | ----------------------------------------------- | ---------------------- |
| `server.go` | `Server` struct, `New`, `ServeHTTP`, `Shutdown` | 4a |
| `log.go` | `muxlog` combined logger; `/logs` handlers | 4a |
| `auth.go` | `CreateAuthMiddleware` | 4d |
| `filters.go` | request-body filter middleware | 4c |
| `api.go` | llama-swap-specific API handlers | 4b / Phase 5 / Phase 6 |
| `ui.go` | embedded UI serving | Phase 7 |
### Phase 4a — package scaffolding -- Completed.
`server.Server` owns the mux, the `local`/`peer` routers, `muxlog`, and a
shutdown context. `New` builds the routers, registers all model-dispatched
routes on a stdlib `http.ServeMux`, and wraps the mux with the global CORS
middleware. `localPeerHandler` resolves the model once via `router.FetchModel`
and dispatches to `local` or `peer`. `Shutdown` stops both routers in parallel
and is idempotent. `cmd/newrouter/main.go` now constructs `server.New(...)`;
`internal/router/server.go` and `server_test.go` were removed as dead code.
### Phase 4b — Custom HTTP endpoints -- Completed.
`GET /v1/models` (local + peer models, aliases, metadata), `GET /health`,
`GET /wol-health`, and `GET /``/ui` are registered. `GET /favicon.ico` is
deferred to Phase 7 since it requires the embedded UI filesystem.
### Phase 4c — Request-body filters -- Completed.
`CreateFilterMiddleware` (in `filters.go`) applies `UseModelName`,
`StripParams`, `SetParams`, and `SetParamsByID` to JSON requests, then
re-attaches the body with `Content-Length` / `Transfer-Encoding` cleanup.
### Phase 4d — Auth & CORS -- Completed.
`CreateAuthMiddleware` validates API keys (Bearer / Basic / `x-api-key`) and
strips the headers before upstream. `CreateCORSMiddleware` answers OPTIONS
preflight; `/v1/models` echoes the `Origin`.
### Phase 4e — Upstream passthrough -- Completed.
`GET /upstream``/ui/models`, and `/upstream/<model>/<path>` proxies to the
resolved model with multi-segment name resolution, canonical-form redirect
(301/308), and prefix stripping.
---
## Phase 5 — Operations endpoints -- Completed.
A new `router.LocalRouter` interface embeds `Router` and adds `RunningModels()`
and `Unload(timeout, models...)`, both implemented once on `baseRouter` so
`Group` and `Matrix` share them — the legacy matrix/group divergence at
[proxymanager.go:1167](../proxy/proxymanager.go#L1167) collapses since
`baseRouter` already unifies process storage. `Peer` does not implement it;
`Server.local` is typed `LocalRouter`, `Server.peer` stays `Router`.
`GET /unload` stops every local process; `GET /running` lists non-stopped
processes joined against config for `cmd`/`proxy`/`ttl`/`name`/`description`.
`startPreload` fires a background `GET /` at each `Hooks.OnStartup.Preload`
model and emits `shared.ModelPreloadedEvent`.
---
## Phase 6 — Metrics, perf, and SSE -- Completed.
`perf.Monitor` is created and started in `cmd/newrouter/main.go` (it outlives
config reloads via `UpdateConfig`) and passed into `server.New`. `GET /metrics`
serves `perf.Monitor.MetricsHandler()` output, 503 when disabled.
`internal/process` emits `shared.ProcessStateChangeEvent` from `setState`.
`server.inflightCounter` (atomic) + `CreateInflightMiddleware` track
model-dispatched requests and emit `InFlightRequestsEvent`. `metricsMonitor`
(in `metrics.go`) parses token usage from upstream responses via
`CreateMetricsMiddleware`.
The `/api` group (API-key protected) is registered: `POST /api/models/unload`,
`POST /api/models/unload/{model...}`, `GET /api/events` (SSE: `modelStatus` /
`logData` / `metrics` / `inflight`), `GET /api/metrics`, `GET /api/performance`
(`?after=` RFC3339 filter), `GET /api/version`. `GET /api/captures/{id}`
returns 501 until 6f.
### Phase 6f — Request/response captures -- Completed.
`proxy/cache` moved to `internal/cache`. `metricsMonitor` stores zstd+CBOR
`ReqRespCapture` records in a sized `cache.Cache` (`captureBuffer` MB, 0
disables). `CreateMetricsMiddleware` buffers request body/headers before
dispatch; `record` builds the capture per a `captureFieldsByPath` table
(`captures.go`) that trims large audio/image payloads, defaulting JSON routes
to `captureAll`. `GET /api/captures/{id}` decompresses and returns the capture;
`getMetrics` resolves `HasCapture` against the cache.
---
## Phase 7 — UI serving -- Completed.
`internal/server/ui.go` embeds `ui_dist` and serves it. `GET /ui/` is
brotli/gzip-aware via `serveCompressedFile`; unknown paths without a file
extension fall back to `index.html` for SPA routing. `GET /favicon.ico` serves
from the same embedded FS. The Makefile `ui` target copies the vite build into
`internal/server/ui_dist`; a committed `placeholder.txt` keeps the embed valid
before a build runs.
---
## Phase 8a - Review Part I
- [x] All functionality from the proxy package has been migrated in the above phases — with the remaining gaps listed in Phase 8b
- [x] Test coverage at or exceeds the level from the proxy package — `internal/server` now at 76.6% vs 73.9% (`proxy`)
### Findings
**Gap 1 — Request logging middleware missing -- Resolved.**
`CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records one
access-log line per request to `s.proxylog` in the legacy format
`clientIP "METHOD PATH PROTO" status bodySize "UA" duration`, skipping
`/wol-health`, `/api/performance`, and `/metrics`. A `statusRecorder` captures
the status/body size (forwarding `Flush` for SSE) and `clientIP` honours
`X-Forwarded-For` / `X-Real-IP`. It is wired as the outermost middleware in
`routes()`, wrapping the CORS layer.
**Gap 2 — Per-model log streaming not supported -- Resolved **
`Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) only handles `""`, `"proxy"`, and `"upstream"`. The legacy `ProxyManager.getLogger` ([proxymanager_loghandlers.go:92](../proxy/proxymanager_loghandlers.go#L92)) additionally resolves a model ID against the active process groups / matrix and returns that process's logger. Callers of `GET /logs/stream/<modelID>` will get a 400 instead of the model's live log stream.
**Gap 3 — `UseModelName` not applied to multipart form endpoints -- Resolved.**
`CreateFormFilterMiddleware` ([filters.go](../internal/server/filters.go)) parses
`multipart/form-data` requests, rewrites the `model` field with `UseModelName`,
reconstructs the body via `rewriteMultipartModel`, and re-attaches it with
`Content-Type` / `Content-Length` cleanup. It runs in `modelChain` after the
JSON `filterMW`; each is a no-op for the other's Content-Type. Audio
transcription (`/v1/audio/transcriptions`) and image edit (`/v1/images/edits`)
now honour `use_model_name`.
**Coverage gaps (0 % functions) -- Resolved.**
The functions previously at 0 % (`handleListModels`, `handleMetrics`,
`handleRootRedirect`, `handleUpstreamRedirect`, `handleUpstream`,
`findModelInPath`, `handleAPICapture`, `handleAPIUnloadAll`,
`handleAPIUnloadModel`, `CreateAuthMiddleware`, `extractAPIKey`,
`handleLogStream`, `applyFilters`, `decompressBody`, `filterAcceptEncoding`,
`handleUI`, `handleFavicon`) now have tests across `auth_test.go`, `api_test.go`,
`filters_test.go`, `log_test.go`, and `extras_test.go`.
---
### Phase 8b - Fill gaps discovered in Phase 8a
- [x] **Add request-log middleware**`CreateRequestLogMiddleware` ([log.go](../internal/server/log.go)) records `clientIP "METHOD PATH PROTO" status bodySize "UA" duration` to `s.proxylog`, skips `/wol-health` / `/api/performance` / `/metrics`, and is wired as the outermost middleware in `routes()`.
- [x] **Extend `getLogger` with model-ID resolution** — add a `default:` branch to `Server.getLogger` ([log.go:50](../internal/server/log.go#L50)) that resolves the ID via `s.local` (using a new `LocalRouter.GetProcess(name)` method or equivalent) and returns that process's `Logger()`. Match the fallback behaviour: return a 400 with `"invalid logger. Use 'proxy', 'upstream' or a model's ID"` when not found.
- [x] **`UseModelName` rewrite for multipart endpoints** — `CreateFormFilterMiddleware` parses `multipart/form-data`, rewrites the `model` field according to `UseModelName`, reconstructs the body, and updates `Content-Type` / `Content-Length`. It is wired into `modelChain` after the JSON filter.
- [x] **Raise test coverage to ≥ 74 %**`internal/server` now at 76.1%; tests added for every 0 % function across `auth_test.go`, `api_test.go`, `filters_test.go`, `log_test.go`, and `extras_test.go`.
---
## Phase 8c - Review Part II (entrypoint comparison)
A second pass comparing [cmd/newrouter/main.go](../cmd/newrouter/main.go) against
the legacy [llama-swap.go](../llama-swap.go) + [proxy.New](../proxy/proxymanager.go#L104)
surfaced four more gaps, all in logger setup.
**Gap 4 — `LogToStdout` config ignored -- Resolved.**
`cmd/newrouter/main.go` previously hardcoded `proxyLog` / `upstreamLog` to
`os.Stdout`, and the old `muxlog()` helper built a Monitor that nothing wrote
into — so `logToStdout` had no effect and `/logs` (combined history) was always
empty. `server.NewLoggers` ([log.go](../internal/server/log.go)) now replicates
the legacy switch: `proxy` / `upstream` monitors feed `muxLog` (or `io.Discard`)
per `none` / `both` / `upstream` / `proxy`, so `muxLog` accumulates the combined
history. `server.New` takes `muxlog` as a parameter. The loggers outlive config
reloads, so a `LogToStdout` change requires a restart to take effect.
**Gap 5 — `LogTimeFormat` config ignored -- Resolved.**
`cmd/newrouter/main.go` now maps `cfg.LogTimeFormat` to a Go time layout via the
`logTimeFormats` table and applies it (alongside log level) to the proxy and
upstream monitors in `applyLogSettings`, re-applied on config reload.
**Gap 6 — `LogRequests` deprecation warning missing.**
The legacy [proxymanager.go:127](../proxy/proxymanager.go#L127) warns when the
deprecated `logRequests` config key is set. `cmd/newrouter` does not. Low
priority — left open.
**Gap 7 — PID debug log missing -- Resolved.**
`cmd/newrouter/main.go` now logs `PID: %d` at debug level after `applyLogSettings`,
matching [llama-swap.go:71](../llama-swap.go#L71).
---
## Phase X (tbd) — Cutover
- [ ] Swap `llama-swap.go` to delegate to `cmd/newrouter` (or rename newrouter to be the primary entrypoint)
- [ ] Update `Makefile` build targets
- [ ] Update docs / README references to the legacy binary
- [ ] Remove `proxy/proxymanager*.go` and `gin-gonic` dependency once nothing imports them
- [ ] Run `make test-all` and confirm concurrency suite still passes against the new entrypoint
---
## Cross-cutting concerns to keep in mind
- **Single body read**: legacy and newrouter both buffer the request body once. When adding filters (Phase 4c), make sure the buffered bytes flow through `Content-Length` / `transfer-encoding` cleanup as in [proxymanager.go:872](../proxy/proxymanager.go#L872).
- **Streaming flag in context**: legacy stashes `streaming` and `model` under `proxyCtxKey`. The new router uses `ModelKey` / `ModelIDKey` — pick one set of keys and use them consistently for metrics + log handlers.
- **Matrix vs Group divergence**: any handler that calls `swapProcessGroup` or `findGroupByModelName` in the legacy needs a matrix branch too. The new router's `Router` interface already abstracts this — preserve that abstraction rather than reintroducing the branch in every handler.
- **Shutdown ordering**: `httpServer.Shutdown` must drain inflight requests _before_ `Server.Shutdown` tears down processes, otherwise inflight requests 502. Current newrouter ordering at [main.go:87](../cmd/newrouter/main.go#L87) is correct — keep it.
+35 -3
View File
@@ -4,22 +4,41 @@ go 1.26.1
require ( require (
github.com/billziss-gh/golib v0.2.0 github.com/billziss-gh/golib v0.2.0
github.com/fsnotify/fsnotify v1.9.0 github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fxamacker/cbor/v2 v2.9.1
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/stretchr/testify v1.9.0 github.com/google/jsonschema-go v0.4.3
github.com/klauspost/compress v1.18.5
github.com/shirou/gopsutil/v4 v4.26.4
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
golang.org/x/sync v0.20.0
golang.org/x/sys v0.41.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/ebitengine/purego v0.10.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect
@@ -27,19 +46,32 @@ require (
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tklauser/go-sysconf v0.3.16 // indirect
github.com/tklauser/numcpus v0.11.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.45.0 // indirect golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
) )
+78 -8
View File
@@ -1,9 +1,31 @@
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8= github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw= github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
@@ -11,14 +33,20 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -29,28 +57,53 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/shirou/gopsutil/v4 v4.26.4 h1:B4SXVbcwTyrocPHEmWBC4uCYr4Xcu3MK1TXqbprAOWY=
github.com/shirou/gopsutil/v4 v4.26.4/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -61,8 +114,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -73,24 +127,40 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
+102
View File
@@ -0,0 +1,102 @@
package cache
import (
"errors"
"sync"
)
var (
ErrExceedsMaxSize = errors.New("item exceeds maximum cache size")
ErrNotFound = errors.New("item not found")
)
type Cache struct {
mu sync.Mutex
items map[int][]byte
order []int
size int
maxSize int
}
func New(maxBytes int) *Cache {
return &Cache{
items: make(map[int][]byte),
order: make([]int, 0),
maxSize: maxBytes,
}
}
func (c *Cache) Add(id int, data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
dataSize := len(data)
if dataSize > c.maxSize {
return ErrExceedsMaxSize
}
// If key already exists, remove old entry from size and order
if old, exists := c.items[id]; exists {
c.size -= len(old)
c.removeOrder(id)
}
// Evict oldest (FIFO) until room available
for c.size+dataSize > c.maxSize && len(c.order) > 0 {
oldestID := c.order[0]
c.order = c.order[1:]
if evicted, exists := c.items[oldestID]; exists {
c.size -= len(evicted)
delete(c.items, oldestID)
}
}
c.items[id] = data
c.order = append(c.order, id)
c.size += dataSize
return nil
}
func (c *Cache) removeOrder(id int) {
for i, v := range c.order {
if v == id {
c.order = append(c.order[:i], c.order[i+1:]...)
return
}
}
}
func (c *Cache) Get(id int) ([]byte, error) {
c.mu.Lock()
defer c.mu.Unlock()
data, exists := c.items[id]
if !exists {
return nil, ErrNotFound
}
return data, nil
}
func (c *Cache) Has(id int) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, exists := c.items[id]
return exists
}
func (c *Cache) Size() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.size
}
func (c *Cache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.items = make(map[int][]byte)
c.order = c.order[:0]
c.size = 0
}
+130
View File
@@ -0,0 +1,130 @@
package cache
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCache_Add(t *testing.T) {
t.Run("adds and retrieves item", func(t *testing.T) {
c := New(1024)
data := []byte("hello")
require.NoError(t, c.Add(1, data))
got, err := c.Get(1)
require.NoError(t, err)
assert.Equal(t, data, got)
})
t.Run("returns error for oversized item", func(t *testing.T) {
c := New(10)
err := c.Add(1, make([]byte, 20))
assert.ErrorIs(t, err, ErrExceedsMaxSize)
})
t.Run("evicts oldest items to make room", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, make([]byte, 40)))
require.NoError(t, c.Add(2, make([]byte, 40)))
// Adding item 3 should evict item 1
require.NoError(t, c.Add(3, make([]byte, 40)))
assert.False(t, c.Has(1))
assert.True(t, c.Has(2))
assert.True(t, c.Has(3))
})
t.Run("overwrites existing key", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, []byte("old")))
require.NoError(t, c.Add(1, []byte("new")))
got, err := c.Get(1)
require.NoError(t, err)
assert.Equal(t, []byte("new"), got)
assert.Equal(t, 3, c.Size())
})
}
func TestCache_Get(t *testing.T) {
t.Run("returns ErrNotFound for missing key", func(t *testing.T) {
c := New(100)
_, err := c.Get(99)
assert.ErrorIs(t, err, ErrNotFound)
})
}
func TestCache_Has(t *testing.T) {
t.Run("returns true for existing key", func(t *testing.T) {
c := New(100)
require.NoError(t, c.Add(1, []byte("data")))
assert.True(t, c.Has(1))
})
t.Run("returns false for missing key", func(t *testing.T) {
c := New(100)
assert.False(t, c.Has(1))
})
}
func TestCache_Size(t *testing.T) {
t.Run("tracks byte usage", func(t *testing.T) {
c := New(1000)
assert.Equal(t, 0, c.Size())
require.NoError(t, c.Add(1, make([]byte, 100)))
assert.Equal(t, 100, c.Size())
require.NoError(t, c.Add(2, make([]byte, 200)))
assert.Equal(t, 300, c.Size())
})
t.Run("updates on eviction", func(t *testing.T) {
c := New(150)
require.NoError(t, c.Add(1, make([]byte, 100)))
require.NoError(t, c.Add(2, make([]byte, 100)))
// Item 1 should be evicted, size = 100
assert.Equal(t, 100, c.Size())
})
}
func TestCache_Clear(t *testing.T) {
t.Run("removes all items and resets size", func(t *testing.T) {
c := New(1000)
require.NoError(t, c.Add(1, []byte("a")))
require.NoError(t, c.Add(2, []byte("b")))
c.Clear()
assert.Equal(t, 0, c.Size())
assert.False(t, c.Has(1))
assert.False(t, c.Has(2))
})
}
func TestCache_Concurrent(t *testing.T) {
t.Run("concurrent operations are safe", func(t *testing.T) {
c := New(10000)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 100; j++ {
key := id*100 + j
_ = c.Add(key, []byte("data"))
_, _ = c.Get(key)
_ = c.Has(key)
_ = c.Size()
}
}(i)
}
wg.Wait()
})
}
+63
View File
@@ -0,0 +1,63 @@
// Package chain composes http.Handler middleware into a single handler.
//
// A Middleware wraps a downstream http.Handler and may run logic before or
// after delegating to it, or short-circuit by not calling next at all
// (e.g. auth failure, CORS preflight).
package chain
import "net/http"
// Middleware wraps an http.Handler with cross-cutting behavior. It receives
// the next handler in the chain and returns a handler that may call next,
// modify the request/response around it, or short-circuit.
type Middleware func(next http.Handler) http.Handler
// Chain is a reusable middleware stack. Build it once with New (and optionally
// extend per-route with Append), then call Then to wrap each terminal handler
// when registering routes against an http.ServeMux:
//
// api := chain.New(authMW, corsMW)
// mux.Handle("/v1/chat/completions", api.Then(dispatch))
// mux.Handle("/v1/embeddings", api.Append(filters).Then(dispatch))
//
// Middlewares execute left-to-right: mws[0] runs first and may call into
// mws[1], and so on, with the terminal handler invoked last. A middleware
// that does not call next short-circuits the remainder of the chain.
// A zero Chain is valid and applies no middleware.
type Chain struct {
mws []Middleware
}
// New returns a Chain that applies mws left-to-right around any terminal
// handler passed to Then.
func New(mws ...Middleware) Chain {
cp := make([]Middleware, len(mws))
copy(cp, mws)
return Chain{mws: cp}
}
// Append returns a new Chain with mws added after the existing middleware.
// The receiver is not modified, so a base Chain can be safely reused across
// multiple routes that each need different per-route additions.
func (c Chain) Append(mws ...Middleware) Chain {
out := make([]Middleware, 0, len(c.mws)+len(mws))
out = append(out, c.mws...)
out = append(out, mws...)
return Chain{mws: out}
}
// Then wraps final with the chain's middleware and returns the resulting
// handler, suitable for passing to http.ServeMux.Handle. With an empty chain,
// Then returns final unchanged.
func (c Chain) Then(final http.Handler) http.Handler {
h := final
for i := len(c.mws) - 1; i >= 0; i-- {
h = c.mws[i](h)
}
return h
}
// ThenFunc is shorthand for Then(http.HandlerFunc(f)).
func (c Chain) ThenFunc(f http.HandlerFunc) http.Handler {
return c.Then(f)
}
+205
View File
@@ -0,0 +1,205 @@
package chain
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// recordingMiddleware appends tag before calling next and "-after-"+tag after.
func recordingMiddleware(tag string, log *[]string) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
*log = append(*log, tag)
next.ServeHTTP(w, r)
*log = append(*log, "after-"+tag)
})
}
}
func TestChain_HandlersExecuteInDeclaredOrder(t *testing.T) {
var log []string
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "final")
})
h := New(
recordingMiddleware("a", &log),
recordingMiddleware("b", &log),
recordingMiddleware("c", &log),
).Then(final)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
h.ServeHTTP(rec, req)
want := []string{"a", "b", "c", "final", "after-c", "after-b", "after-a"}
if !equal(log, want) {
t.Fatalf("execution order mismatch:\n got: %v\nwant: %v", log, want)
}
}
func TestChain_ShortCircuitsWhenMiddlewareDoesNotCallNext(t *testing.T) {
var log []string
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "final")
})
gate := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "gate")
w.WriteHeader(http.StatusUnauthorized)
})
}
h := New(
recordingMiddleware("outer", &log),
gate,
recordingMiddleware("inner", &log),
).Then(final)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
h.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
}
want := []string{"outer", "gate", "after-outer"}
if !equal(log, want) {
t.Fatalf("short-circuit order mismatch:\n got: %v\nwant: %v", log, want)
}
}
func TestChain_EarlyWritesAreVisibleToLaterMiddleware(t *testing.T) {
header := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Set-By", "outer")
_, _ = io.WriteString(w, "outer:")
next.ServeHTTP(w, r)
})
}
inner := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// The outer middleware already set the header; we should see it.
if got := w.Header().Get("X-Set-By"); got != "outer" {
_, _ = io.WriteString(w, "missing-header;")
}
_, _ = io.WriteString(w, "inner:")
next.ServeHTTP(w, r)
})
}
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "final")
})
h := New(header, inner).Then(final)
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
body, _ := io.ReadAll(rec.Body)
if got := string(body); !strings.Contains(got, "outer:inner:final") {
t.Fatalf("body: got %q, want it to contain %q", got, "outer:inner:final")
}
if got := rec.Header().Get("X-Set-By"); got != "outer" {
t.Fatalf("header X-Set-By: got %q, want %q", got, "outer")
}
}
func TestChain_ReusableAcrossRoutesViaThen(t *testing.T) {
var log []string
base := New(
recordingMiddleware("auth", &log),
recordingMiddleware("cors", &log),
)
mux := http.NewServeMux()
mux.Handle("/a", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "handler-a")
}))
mux.Handle("/b", base.ThenFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "handler-b")
}))
srv := httptest.NewServer(mux)
defer srv.Close()
for _, path := range []string{"/a", "/b"} {
resp, err := http.Get(srv.URL + path)
if err != nil {
t.Fatalf("GET %s: %v", path, err)
}
resp.Body.Close()
}
want := []string{
"auth", "cors", "handler-a", "after-cors", "after-auth",
"auth", "cors", "handler-b", "after-cors", "after-auth",
}
if !equal(log, want) {
t.Fatalf("reusable chain order mismatch:\n got: %v\nwant: %v", log, want)
}
}
func TestChain_AppendDoesNotMutateReceiver(t *testing.T) {
var log []string
base := New(recordingMiddleware("base", &log))
extended := base.Append(recordingMiddleware("extra", &log))
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "final")
})
// Run extended first to surface any aliasing of the underlying slice.
rec := httptest.NewRecorder()
extended.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
rec = httptest.NewRecorder()
base.Then(final).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
want := []string{
"base", "extra", "final", "after-extra", "after-base",
"base", "final", "after-base",
}
if !equal(log, want) {
t.Fatalf("Append must not mutate the receiver:\n got: %v\nwant: %v", log, want)
}
}
func TestChain_ZeroValueAndEmptyThenAreIdentity(t *testing.T) {
final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
})
for name, c := range map[string]Chain{
"zero": {},
"empty": New(),
} {
t.Run(name, func(t *testing.T) {
h := c.Then(final)
if _, ok := h.(http.HandlerFunc); !ok {
t.Fatalf("expected http.HandlerFunc identity, got %T", h)
}
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
if rec.Code != http.StatusTeapot {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusTeapot)
}
})
}
}
func equal(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
@@ -9,6 +9,7 @@ import (
"runtime" "runtime"
"sort" "sort"
"strings" "strings"
"time"
"github.com/billziss-gh/golib/shlex" "github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -124,16 +125,20 @@ type Config struct {
LogToStdout string `yaml:"logToStdout"` LogToStdout string `yaml:"logToStdout"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"` MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
CaptureBuffer int `yaml:"captureBuffer"` CaptureBuffer int `yaml:"captureBuffer"`
Performance PerformanceConfig `yaml:"performance"`
GlobalTTL int `yaml:"globalTTL"` GlobalTTL int `yaml:"globalTTL"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */ Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// swap matrix: solver-based alternative to groups // routing is the canonical source for swap/scheduling configuration.
Matrix *MatrixConfig `yaml:"matrix"` // New code must read Routing, never the backwards-compat fields below.
Routing RoutingConfig `yaml:"routing"`
// populated during validation when matrix is configured // Groups and Matrix are permanent backwards-compat input fields for the
ExpandedSets []ExpandedSet `yaml:"-"` // legacy top-level `groups:`/`matrix:` keys. They are normalized into
// Routing by LoadConfigFromReader. New code must not read them directly.
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
Matrix *MatrixConfig `yaml:"matrix"`
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint // for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros MacroList `yaml:"macros"` Macros MacroList `yaml:"macros"`
@@ -160,6 +165,35 @@ type Config struct {
Peers PeerDictionaryConfig `yaml:"peers"` Peers PeerDictionaryConfig `yaml:"peers"`
} }
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
type RoutingConfig struct {
Scheduler SchedulerConfig `yaml:"scheduler"`
Router RouterConfig `yaml:"router"`
}
type SchedulerConfig struct {
Use string `yaml:"use"` // default "fifo"
Settings SchedulerSettings `yaml:"settings"`
}
type SchedulerSettings struct {
Fifo FifoConfig `yaml:"fifo"`
}
type FifoConfig struct {
Priority map[string]int `yaml:"priority"` // model ID -> priority, default 0
}
type RouterConfig struct {
Use string `yaml:"use"` // "group" (default) | "matrix"
Settings RouterSettings `yaml:"settings"`
}
type RouterSettings struct {
Groups map[string]GroupConfig `yaml:"groups"`
Matrix *MatrixConfig `yaml:"matrix"`
}
func (c *Config) RealModelName(search string) (string, bool) { func (c *Config) RealModelName(search string) (string, bool) {
if _, found := c.Models[search]; found { if _, found := c.Models[search]; found {
return search, true return search, true
@@ -220,6 +254,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
config.HealthCheckTimeout = 15 config.HealthCheckTimeout = 15
} }
// Apply defaults for performance config when section is missing
if config.Performance.Every == 0 {
config.Performance.Every = 5 * time.Second
}
if err = config.Performance.Validate(); err != nil {
return Config{}, fmt.Errorf("performance: %w", err)
}
if config.StartPort < 1 { if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1") return Config{}, fmt.Errorf("startPort must be greater than 1")
} }
@@ -262,6 +304,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
nextPort := config.StartPort nextPort := config.StartPort
for _, modelId := range modelIds { for _, modelId := range modelIds {
modelConfig := config.Models[modelId] modelConfig := config.Models[modelId]
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
// Strip comments from command fields // Strip comments from command fields
modelConfig.Cmd = StripComments(modelConfig.Cmd) modelConfig.Cmd = StripComments(modelConfig.Cmd)
@@ -404,6 +447,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
} }
} }
if err = modelConfig.Capabilities.Validate(); err != nil {
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
}
// Validate SetParamsByID keys and values // Validate SetParamsByID keys and values
for key, paramMap := range modelConfig.Filters.SetParamsByID { for key, paramMap := range modelConfig.Filters.SetParamsByID {
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 { if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
@@ -444,6 +491,34 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
config.Models[modelId] = modelConfig config.Models[modelId] = modelConfig
} }
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
// the new `routing.router` block are mutually exclusive: a config may use
// either style, never both.
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
rtr := config.Routing.Router
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
if hasTopLevel && hasRouting {
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
}
if !hasTopLevel {
// Both groups and matrix may be defined under routing.router.settings;
// routing.router.use selects which one is active, so there is no conflict.
rs := config.Routing.Router.Settings
switch config.Routing.Router.Use {
case "matrix":
if rs.Matrix == nil {
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
}
config.Matrix = rs.Matrix
case "group", "":
config.Groups = rs.Groups
default:
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
}
}
// groups XOR matrix // groups XOR matrix
if config.Matrix != nil && len(config.Groups) > 0 { if config.Matrix != nil && len(config.Groups) > 0 {
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'") return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
@@ -454,7 +529,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
if err != nil { if err != nil {
return Config{}, fmt.Errorf("matrix: %w", err) return Config{}, fmt.Errorf("matrix: %w", err)
} }
config.ExpandedSets = expandedSets config.Matrix.ExpandedSets = expandedSets
} else { } else {
config = AddDefaultGroupToConfig(config) config = AddDefaultGroupToConfig(config)
@@ -476,6 +551,29 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
} }
} }
// Build the canonical Config.Routing from the effective result. Both legacy
// and new-style configs converge here. The Matrix pointer is shared so
// ExpandedSets stays in one place.
if config.Matrix != nil {
config.Routing.Router.Use = "matrix"
} else {
config.Routing.Router.Use = "group"
}
config.Routing.Router.Settings.Matrix = config.Matrix
config.Routing.Router.Settings.Groups = config.Groups
if config.Routing.Scheduler.Use == "" {
config.Routing.Scheduler.Use = "fifo"
}
if config.Routing.Scheduler.Use != "fifo" {
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
}
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
if _, found := config.RealModelName(modelID); !found {
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
}
}
// Clean up hooks preload // Clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 { if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string var toPreload []string
@@ -646,9 +744,6 @@ func validateMacro(name string, value any) error {
// Validate that value is a scalar type // Validate that value is a scalar type
switch v := value.(type) { switch v := value.(type) {
case string: case string:
if len(v) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
// Check for self-reference // Check for self-reference
macroSlug := fmt.Sprintf("${%s}", name) macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(v, macroSlug) { if strings.Contains(v, macroSlug) {
@@ -7,6 +7,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -172,6 +173,25 @@ groups:
IdleConn: 90, IdleConn: 90,
} }
expectedGroups := map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
}
expected := Config{ expected := Config{
LogLevel: "info", LogLevel: "info",
LogTimeFormat: "", LogTimeFormat: "",
@@ -188,47 +208,54 @@ groups:
SendLoadingState: false, SendLoadingState: false,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080", Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"}, Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"}, Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health", CheckEndpoint: "/health",
Name: "Model 1", Name: "Model 1",
Description: "This is model 1", Description: "This is model 1",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout, Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
}, },
"model2": { "model2": {
Cmd: "path/to/server --arg1 one", Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081", Proxy: "http://localhost:8081",
Aliases: []string{"m2"}, Aliases: []string{"m2"},
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout, Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
}, },
"model3": { "model3": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081", Proxy: "http://localhost:8081",
Aliases: []string{"mthree"}, Aliases: []string{"mthree"},
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout, Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
}, },
"model4": { "model4": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082", Proxy: "http://localhost:8082",
CheckEndpoint: "/", CheckEndpoint: "/",
Aliases: []string{}, Aliases: []string{},
Env: []string{}, Env: []string{},
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout, Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
}, },
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000, MetricsMaxInMemory: 1000,
CaptureBuffer: 5, CaptureBuffer: 5,
Performance: PerformanceConfig{
Every: 5 * time.Second,
},
Profiles: map[string][]string{ Profiles: map[string][]string{
"test": {"model1", "model2"}, "test": {"model1", "model2"},
}, },
@@ -238,22 +265,16 @@ groups:
"m2": "model2", "m2": "model2",
"mthree": "model3", "mthree": "model3",
}, },
Groups: map[string]GroupConfig{ Groups: expectedGroups,
DEFAULT_GROUP_ID: { Routing: RoutingConfig{
Swap: true, Router: RouterConfig{
Exclusive: true, Use: "group",
Members: []string{"model1", "model3"}, Settings: RouterSettings{
Groups: expectedGroups,
},
}, },
"group1": { Scheduler: SchedulerConfig{
Swap: true, Use: "fifo",
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
}, },
}, },
} }
+60
View File
@@ -0,0 +1,60 @@
package config
import (
"encoding/json"
"os"
"testing"
"github.com/google/jsonschema-go/jsonschema"
"gopkg.in/yaml.v3"
)
// TestConfig_ExampleMatchesSchema validates that config.example.yaml conforms to
// config-schema.json. Both files live at the repository root.
func TestConfig_ExampleMatchesSchema(t *testing.T) {
const (
schemaPath = "../../config-schema.json"
examplePath = "../../config.example.yaml"
)
schemaBytes, err := os.ReadFile(schemaPath)
if err != nil {
t.Fatalf("reading %s: %v", schemaPath, err)
}
var schema jsonschema.Schema
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
t.Fatalf("unmarshalling schema: %v", err)
}
resolved, err := schema.Resolve(&jsonschema.ResolveOptions{
BaseURI: "https://github.com/mostlygeek/llama-swap/",
})
if err != nil {
t.Fatalf("resolving schema: %v", err)
}
exampleBytes, err := os.ReadFile(examplePath)
if err != nil {
t.Fatalf("reading %s: %v", examplePath, err)
}
// Convert YAML to a JSON-like value so numbers and keys match what the
// validator expects.
var yamlValue any
if err := yaml.Unmarshal(exampleBytes, &yamlValue); err != nil {
t.Fatalf("unmarshalling example yaml: %v", err)
}
jsonBytes, err := json.Marshal(yamlValue)
if err != nil {
t.Fatalf("converting example to json: %v", err)
}
var instance any
if err := json.Unmarshal(jsonBytes, &instance); err != nil {
t.Fatalf("unmarshalling example json: %v", err)
}
if err := resolved.Validate(instance); err != nil {
t.Fatalf("config.example.yaml does not match config-schema.json:\n%v", err)
}
}
@@ -1544,3 +1544,174 @@ peers:
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue) assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn) assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
} }
// twoModels is a minimal models block reused by the routing tests below.
const twoModels = `
models:
gemma:
cmd: echo gemma
proxy: http://localhost:8080
qwen:
cmd: echo qwen
proxy: http://localhost:8081
`
func TestConfig_Routing_LegacyTopLevelGroups(t *testing.T) {
yaml := twoModels + `
groups:
g1:
members: [gemma, qwen]
`
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
assert.Equal(t, "group", cfg.Routing.Router.Use)
// default group injected for orphaned models (none here) still leaves g1
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
}
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
yaml := twoModels + `
matrix:
vars:
g: gemma
q: qwen
sets:
combo: "g | q"
`
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
}
func TestConfig_Routing_RouterUseMatrix(t *testing.T) {
yaml := twoModels + `
routing:
router:
use: matrix
settings:
matrix:
vars:
g: gemma
q: qwen
sets:
combo: "g | q"
`
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
require.NotNil(t, cfg.Routing.Router.Settings.Matrix)
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
}
func TestConfig_Routing_RouterUseGroup(t *testing.T) {
yaml := twoModels + `
routing:
router:
use: group
settings:
groups:
g1:
members: [gemma, qwen]
`
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
assert.Equal(t, "group", cfg.Routing.Router.Use)
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
}
func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
require.NoError(t, err)
assert.Equal(t, "group", cfg.Routing.Router.Use)
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
}
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
yaml := twoModels + `
groups:
g1:
members: [gemma, qwen]
routing:
router:
use: group
`
_, err := LoadConfigFromReader(strings.NewReader(yaml))
require.Error(t, err)
assert.Contains(t, err.Error(), "migrate")
}
func TestConfig_Routing_RouterUseMatrixWithoutSettings(t *testing.T) {
yaml := twoModels + `
routing:
router:
use: matrix
`
_, err := LoadConfigFromReader(strings.NewReader(yaml))
require.Error(t, err)
assert.Contains(t, err.Error(), "routing.router.settings.matrix is not set")
}
// Both groups and matrix may be defined under routing.router.settings;
// routing.router.use selects which one is active.
func TestConfig_Routing_RouterSettingsBothGroupsAndMatrix(t *testing.T) {
yaml := twoModels + `
routing:
router:
use: group
settings:
groups:
g1:
members: [gemma, qwen]
matrix:
sets:
s: "gemma"
`
config, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
// use: group means groups are active and matrix is ignored
assert.Equal(t, "group", config.Routing.Router.Use)
assert.Nil(t, config.Matrix)
assert.Contains(t, config.Groups, "g1")
}
func TestConfig_Routing_UnknownRouter(t *testing.T) {
yaml := twoModels + `
routing:
router:
use: bogus
`
_, err := LoadConfigFromReader(strings.NewReader(yaml))
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown router")
}
func TestConfig_Routing_FifoPriorityUnknownModel(t *testing.T) {
yaml := twoModels + `
routing:
scheduler:
settings:
fifo:
priority:
nope: 5
`
_, err := LoadConfigFromReader(strings.NewReader(yaml))
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown model")
}
func TestConfig_Routing_FifoPriorityKnownModel(t *testing.T) {
yaml := twoModels + `
routing:
scheduler:
settings:
fifo:
priority:
gemma: 5
`
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
assert.Equal(t, 5, cfg.Routing.Scheduler.Settings.Fifo.Priority["gemma"])
}
@@ -7,6 +7,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -164,6 +165,25 @@ groups:
IdleConn: 90, IdleConn: 90,
} }
expectedGroups := map[string]GroupConfig{
DEFAULT_GROUP_ID: {
Swap: true,
Exclusive: true,
Members: []string{"model1", "model3"},
},
"group1": {
Swap: true,
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
},
}
expected := Config{ expected := Config{
LogLevel: "info", LogLevel: "info",
LogTimeFormat: "", LogTimeFormat: "",
@@ -175,49 +195,56 @@ groups:
SendLoadingState: false, SendLoadingState: false,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}", CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080", Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"}, Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"}, Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health", CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout, Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
}, },
"model2": { "model2": {
Cmd: "path/to/server --arg1 one", Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}", CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081", Proxy: "http://localhost:8081",
Aliases: []string{"m2"}, Aliases: []string{"m2"},
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout, Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
}, },
"model3": { "model3": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}", CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081", Proxy: "http://localhost:8081",
Aliases: []string{"mthree"}, Aliases: []string{"mthree"},
Env: []string{}, Env: []string{},
CheckEndpoint: "/", CheckEndpoint: "/",
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout, Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
}, },
"model4": { "model4": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}", CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082", Proxy: "http://localhost:8082",
CheckEndpoint: "/", CheckEndpoint: "/",
Aliases: []string{}, Aliases: []string{},
Env: []string{}, Env: []string{},
SendLoadingState: &modelLoadingState, SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout, Timeouts: defaultTimeout,
HealthCheckTimeout: 15,
}, },
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000, MetricsMaxInMemory: 1000,
CaptureBuffer: 5, CaptureBuffer: 5,
Performance: PerformanceConfig{
Every: 5 * time.Second,
},
Profiles: map[string][]string{ Profiles: map[string][]string{
"test": {"model1", "model2"}, "test": {"model1", "model2"},
}, },
@@ -227,22 +254,16 @@ groups:
"m2": "model2", "m2": "model2",
"mthree": "model3", "mthree": "model3",
}, },
Groups: map[string]GroupConfig{ Groups: expectedGroups,
DEFAULT_GROUP_ID: { Routing: RoutingConfig{
Swap: true, Router: RouterConfig{
Exclusive: true, Use: "group",
Members: []string{"model1", "model3"}, Settings: RouterSettings{
Groups: expectedGroups,
},
}, },
"group1": { Scheduler: SchedulerConfig{
Swap: true, Use: "fifo",
Exclusive: false,
Members: []string{"model2"},
},
"forever": {
Swap: true,
Exclusive: false,
Persistent: true,
Members: []string{"model4"},
}, },
}, },
} }
@@ -15,6 +15,9 @@ type MatrixConfig struct {
Var map[string]string `yaml:"vars"` Var map[string]string `yaml:"vars"`
EvictCosts map[string]int `yaml:"evict_costs"` EvictCosts map[string]int `yaml:"evict_costs"`
Sets OrderedSets `yaml:"sets"` Sets OrderedSets `yaml:"sets"`
// populated by ValidateMatrix; not settable from yaml
ExpandedSets []ExpandedSet `yaml:"-"`
} }
// SetEntry is a single named set with its DSL expression. // SetEntry is a single named set with its DSL expression.
@@ -289,7 +289,9 @@ matrix:
cfg, err := LoadConfigFromReader(strings.NewReader(yaml)) cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, cfg.Matrix) assert.NotNil(t, cfg.Matrix)
assert.Len(t, cfg.ExpandedSets, 2) assert.Len(t, cfg.Matrix.ExpandedSets, 2)
assert.Equal(t, "matrix", cfg.Routing.Router.Use)
assert.Len(t, cfg.Routing.Router.Settings.Matrix.ExpandedSets, 2)
// Groups should be empty when matrix is used // Groups should be empty when matrix is used
assert.Empty(t, cfg.Groups) assert.Empty(t, cfg.Groups)
} }
@@ -2,6 +2,7 @@ package config
import ( import (
"errors" "errors"
"fmt"
"runtime" "runtime"
) )
@@ -9,6 +10,47 @@ const (
MODEL_CONFIG_DEFAULT_TTL = -1 MODEL_CONFIG_DEFAULT_TTL = -1
) )
var validModalities = map[string]struct{}{
"text": {},
"audio": {},
"image": {},
}
// ModelCapConfig defines what modalities and features a model supports.
// Used in /v1/models to inform clients. An empty block (all zero values) is
// treated as not configured.
type ModelCapConfig struct {
In []string `yaml:"in"`
Out []string `yaml:"out"`
Tools bool `yaml:"tools"`
Reranker bool `yaml:"reranker"`
Context int `yaml:"context"`
}
// Empty returns true when all fields are at their zero values.
func (c ModelCapConfig) Empty() bool {
return len(c.In) == 0 && len(c.Out) == 0 && !c.Tools && !c.Reranker && c.Context == 0
}
// Validate checks that all modality values are recognized and context is
// non-negative. Returns an error if any value is invalid.
func (c ModelCapConfig) Validate() error {
for _, m := range c.In {
if _, ok := validModalities[m]; !ok {
return fmt.Errorf("capabilities.in: invalid modality %q, must be one of: text, audio, image", m)
}
}
for _, m := range c.Out {
if _, ok := validModalities[m]; !ok {
return fmt.Errorf("capabilities.out: invalid modality %q, must be one of: text, audio, image", m)
}
}
if c.Context < 0 {
return errors.New("capabilities.context: must be >= 0")
}
return nil
}
// TimeoutsConfig holds timeout settings for proxy connections // TimeoutsConfig holds timeout settings for proxy connections
// 0 = no timeout // 0 = no timeout
type TimeoutsConfig struct { type TimeoutsConfig struct {
@@ -54,6 +96,12 @@ type ModelConfig struct {
// Timeout settings for proxy connections // Timeout settings for proxy connections
Timeouts TimeoutsConfig `yaml:"timeouts"` Timeouts TimeoutsConfig `yaml:"timeouts"`
// Capabilities defines what modalities and features the model supports.
Capabilities ModelCapConfig `yaml:"capabilities"`
// Copy of HealthCheckTimeout from global config
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
} }
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
@@ -152,7 +152,7 @@ models:
stop: stop:
- "<|end|>" - "<|end|>"
- "<|stop|>" - "<|stop|>"
` `
config, err := LoadConfigFromReader(strings.NewReader(content)) config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err) assert.NoError(t, err)
@@ -170,3 +170,167 @@ models:
assert.Equal(t, 0.7, setParams["temperature"]) assert.Equal(t, 0.7, setParams["temperature"])
assert.Equal(t, 0.9, setParams["top_p"]) assert.Equal(t, 0.9, setParams["top_p"])
} }
func TestConfig_ModelCapabilities(t *testing.T) {
t.Run("all fields", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
in:
- text
- audio
- image
out:
- text
- audio
- image
tools: true
context: 32000
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.False(t, mc.Capabilities.Empty())
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.In)
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.Out)
assert.True(t, mc.Capabilities.Tools)
assert.Equal(t, 32000, mc.Capabilities.Context)
})
t.Run("partial fields", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
tools: true
context: 8192
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.False(t, mc.Capabilities.Empty())
assert.Nil(t, mc.Capabilities.In)
assert.Nil(t, mc.Capabilities.Out)
assert.True(t, mc.Capabilities.Tools)
assert.Equal(t, 8192, mc.Capabilities.Context)
})
t.Run("not set", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.True(t, mc.Capabilities.Empty())
})
t.Run("tools false is empty", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
tools: false
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.True(t, mc.Capabilities.Empty())
})
t.Run("reranker true is not empty", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
reranker: true
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.False(t, mc.Capabilities.Empty())
assert.True(t, mc.Capabilities.Reranker)
})
t.Run("reranker false is empty", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
reranker: false
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.True(t, mc.Capabilities.Empty())
})
}
func TestConfig_ModelCapabilities_Validate(t *testing.T) {
t.Run("valid_modalities", func(t *testing.T) {
caps := ModelCapConfig{
In: []string{"text", "image"},
Out: []string{"text", "audio"},
Tools: true,
Context: 100000,
}
assert.NoError(t, caps.Validate())
})
t.Run("empty_is_valid", func(t *testing.T) {
caps := ModelCapConfig{}
assert.NoError(t, caps.Validate())
})
t.Run("invalid_in_modality", func(t *testing.T) {
caps := ModelCapConfig{In: []string{"video"}}
err := caps.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "capabilities.in")
assert.Contains(t, err.Error(), "video")
})
t.Run("invalid_out_modality", func(t *testing.T) {
caps := ModelCapConfig{Out: []string{"video"}}
err := caps.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "capabilities.out")
assert.Contains(t, err.Error(), "video")
})
t.Run("negative_context", func(t *testing.T) {
caps := ModelCapConfig{Context: -1}
err := caps.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "capabilities.context")
})
t.Run("rejects_invalid_at_load", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
in:
- text
- video
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "video")
})
}
+34
View File
@@ -0,0 +1,34 @@
package config
import (
"fmt"
"time"
)
// PerformanceConfig holds configuration for system performance monitoring
type PerformanceConfig struct {
Disabled bool `yaml:"disabled"`
Every time.Duration `yaml:"every"`
}
func (p *PerformanceConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawPerformanceConfig PerformanceConfig
defaults := rawPerformanceConfig{
Every: 5 * time.Second,
}
if err := unmarshal(&defaults); err != nil {
return err
}
*p = PerformanceConfig(defaults)
return nil
}
// Validate checks the PerformanceConfig values and returns an error if invalid
func (p *PerformanceConfig) Validate() error {
if p.Every < 5*time.Second {
return fmt.Errorf("every must be at least 5s, got %v", p.Every)
}
return nil
}
+98
View File
@@ -0,0 +1,98 @@
package config
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPerformanceConfig_Defaults(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// When performance section is missing, defaults should be applied
assert.False(t, config.Performance.Disabled)
assert.Equal(t, 5*time.Second, config.Performance.Every)
}
func TestPerformanceConfig_CustomValues(t *testing.T) {
content := `
performance:
enable: true
every: 30s
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.False(t, config.Performance.Disabled)
assert.Equal(t, 30*time.Second, config.Performance.Every)
}
func TestPerformanceConfig_Disabled(t *testing.T) {
content := `
performance:
disabled: true
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.True(t, config.Performance.Disabled)
// Duration defaults should still apply
assert.Equal(t, 5*time.Second, config.Performance.Every)
}
func TestPerformanceConfig_PartialValues(t *testing.T) {
content := `
performance:
every: 10s
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// enable should default to true
assert.False(t, config.Performance.Disabled)
assert.Equal(t, 10*time.Second, config.Performance.Every)
}
func TestPerformanceConfig_InvalidEvery(t *testing.T) {
content := `
performance:
every: 4s
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "every must be at least 5s")
}
func TestPerformanceConfig_ComplexDurations(t *testing.T) {
content := `
performance:
every: 1m30s
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 90*time.Second, config.Performance.Every)
}
@@ -1,54 +1,54 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved. // Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile. // Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event package event
import ( import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
/* /*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
*/ */
func BenchmarkSubscribeConcurrent(b *testing.B) { func BenchmarkSubscribeConcurrent(b *testing.B) {
d := NewDispatcher() d := NewDispatcher()
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
unsub := Subscribe(d, func(ev MyEvent1) {}) unsub := Subscribe(d, func(ev MyEvent1) {})
unsub() unsub()
} }
}) })
} }
func TestDefaultPublish(t *testing.T) { func TestDefaultPublish(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
// Subscribe // Subscribe
var count int64 var count int64
defer On(func(ev MyEvent1) { defer On(func(ev MyEvent1) {
atomic.AddInt64(&count, 1) atomic.AddInt64(&count, 1)
wg.Done() wg.Done()
})() })()
defer OnType(TypeEvent1, func(ev MyEvent1) { defer OnType(TypeEvent1, func(ev MyEvent1) {
atomic.AddInt64(&count, 1) atomic.AddInt64(&count, 1)
wg.Done() wg.Done()
})() })()
// Publish // Publish
wg.Add(4) wg.Add(4)
Emit(MyEvent1{}) Emit(MyEvent1{})
Emit(MyEvent1{}) Emit(MyEvent1{})
// Wait and check // Wait and check
wg.Wait() wg.Wait()
assert.Equal(t, int64(4), count) assert.Equal(t, int64(4), count)
} }
@@ -1,324 +1,324 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved. // Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile. // Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event package event
import ( import (
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestPublish(t *testing.T) { func TestPublish(t *testing.T) {
d := NewDispatcher() d := NewDispatcher()
var wg sync.WaitGroup var wg sync.WaitGroup
// Subscribe, must be received in order // Subscribe, must be received in order
var count int64 var count int64
defer Subscribe(d, func(ev MyEvent1) { defer Subscribe(d, func(ev MyEvent1) {
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number) assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
wg.Done() wg.Done()
})() })()
// Publish // Publish
wg.Add(3) wg.Add(3)
Publish(d, MyEvent1{Number: 1}) Publish(d, MyEvent1{Number: 1})
Publish(d, MyEvent1{Number: 2}) Publish(d, MyEvent1{Number: 2})
Publish(d, MyEvent1{Number: 3}) Publish(d, MyEvent1{Number: 3})
// Wait and check // Wait and check
wg.Wait() wg.Wait()
assert.Equal(t, int64(3), count) assert.Equal(t, int64(3), count)
} }
func TestUnsubscribe(t *testing.T) { func TestUnsubscribe(t *testing.T) {
d := NewDispatcher() d := NewDispatcher()
assert.Equal(t, 0, d.count(TypeEvent1)) assert.Equal(t, 0, d.count(TypeEvent1))
unsubscribe := Subscribe(d, func(ev MyEvent1) { unsubscribe := Subscribe(d, func(ev MyEvent1) {
// Nothing // Nothing
}) })
assert.Equal(t, 1, d.count(TypeEvent1)) assert.Equal(t, 1, d.count(TypeEvent1))
unsubscribe() unsubscribe()
assert.Equal(t, 0, d.count(TypeEvent1)) assert.Equal(t, 0, d.count(TypeEvent1))
} }
func TestConcurrent(t *testing.T) { func TestConcurrent(t *testing.T) {
const max = 1000000 const max = 1000000
var count int64 var count int64
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
d := NewDispatcher() d := NewDispatcher()
defer Subscribe(d, func(ev MyEvent1) { defer Subscribe(d, func(ev MyEvent1) {
if current := atomic.AddInt64(&count, 1); current == max { if current := atomic.AddInt64(&count, 1); current == max {
wg.Done() wg.Done()
} }
})() })()
// Asynchronously publish // Asynchronously publish
go func() { go func() {
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
Publish(d, MyEvent1{}) Publish(d, MyEvent1{})
} }
}() }()
defer Subscribe(d, func(ev MyEvent1) { defer Subscribe(d, func(ev MyEvent1) {
// Subscriber that does nothing // Subscriber that does nothing
})() })()
wg.Wait() wg.Wait()
assert.Equal(t, max, int(count)) assert.Equal(t, max, int(count))
} }
func TestSubscribeDifferentType(t *testing.T) { func TestSubscribeDifferentType(t *testing.T) {
d := NewDispatcher() d := NewDispatcher()
assert.Panics(t, func() { assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {}) SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {}) SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
}) })
} }
func TestPublishDifferentType(t *testing.T) { func TestPublishDifferentType(t *testing.T) {
d := NewDispatcher() d := NewDispatcher()
assert.Panics(t, func() { assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {}) SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
Publish(d, MyEvent1{}) Publish(d, MyEvent1{})
}) })
} }
func TestCloseDispatcher(t *testing.T) { func TestCloseDispatcher(t *testing.T) {
d := NewDispatcher() d := NewDispatcher()
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})() defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
assert.NoError(t, d.Close()) assert.NoError(t, d.Close())
assert.Panics(t, func() { assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {}) SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
}) })
} }
func TestMatrix(t *testing.T) { func TestMatrix(t *testing.T) {
const amount = 1000 const amount = 1000
for _, subs := range []int{1, 10, 100} { for _, subs := range []int{1, 10, 100} {
for _, topics := range []int{1, 10} { for _, topics := range []int{1, 10} {
expected := subs * topics * amount expected := subs * topics * amount
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) { t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
var count atomic.Int64 var count atomic.Int64
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(expected) wg.Add(expected)
d := NewDispatcher() d := NewDispatcher()
for i := 0; i < subs; i++ { for i := 0; i < subs; i++ {
for id := 0; id < topics; id++ { for id := 0; id < topics; id++ {
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) { defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
count.Add(1) count.Add(1)
wg.Done() wg.Done()
})() })()
} }
} }
for n := 0; n < amount; n++ { for n := 0; n < amount; n++ {
for id := 0; id < topics; id++ { for id := 0; id < topics; id++ {
go Publish(d, MyEvent3{ID: id}) go Publish(d, MyEvent3{ID: id})
} }
} }
wg.Wait() wg.Wait()
assert.Equal(t, expected, int(count.Load())) assert.Equal(t, expected, int(count.Load()))
}) })
} }
} }
} }
func TestConcurrentSubscriptionRace(t *testing.T) { func TestConcurrentSubscriptionRace(t *testing.T) {
// This test specifically targets the race condition that occurs when multiple // This test specifically targets the race condition that occurs when multiple
// goroutines try to subscribe to different event types simultaneously. // goroutines try to subscribe to different event types simultaneously.
// Without the CAS loop, subscriptions could be lost due to registry corruption. // Without the CAS loop, subscriptions could be lost due to registry corruption.
const numGoroutines = 100 const numGoroutines = 100
const numEventTypes = 50 const numEventTypes = 50
d := NewDispatcher() d := NewDispatcher()
defer d.Close() defer d.Close()
var wg sync.WaitGroup var wg sync.WaitGroup
var receivedCount int64 var receivedCount int64
var subscribedTypes sync.Map // Thread-safe map var subscribedTypes sync.Map // Thread-safe map
wg.Add(numGoroutines) wg.Add(numGoroutines)
// Start multiple goroutines that subscribe to different event types concurrently // Start multiple goroutines that subscribe to different event types concurrently
for i := 0; i < numGoroutines; i++ { for i := 0; i < numGoroutines; i++ {
go func(goroutineID int) { go func(goroutineID int) {
defer wg.Done() defer wg.Done()
// Each goroutine subscribes to a unique event type // Each goroutine subscribes to a unique event type
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
// Subscribe to the event type // Subscribe to the event type
SubscribeTo(d, eventType, func(ev MyEvent3) { SubscribeTo(d, eventType, func(ev MyEvent3) {
atomic.AddInt64(&receivedCount, 1) atomic.AddInt64(&receivedCount, 1)
}) })
// Record that this type was subscribed // Record that this type was subscribed
subscribedTypes.Store(eventType, true) subscribedTypes.Store(eventType, true)
}(i) }(i)
} }
// Wait for all subscriptions to complete // Wait for all subscriptions to complete
wg.Wait() wg.Wait()
// Count the number of unique event types subscribed // Count the number of unique event types subscribed
expectedTypes := 0 expectedTypes := 0
subscribedTypes.Range(func(key, value interface{}) bool { subscribedTypes.Range(func(key, value interface{}) bool {
expectedTypes++ expectedTypes++
return true return true
}) })
// Small delay to ensure all subscriptions are fully processed // Small delay to ensure all subscriptions are fully processed
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
// Publish events to each subscribed type // Publish events to each subscribed type
subscribedTypes.Range(func(key, value interface{}) bool { subscribedTypes.Range(func(key, value interface{}) bool {
eventType := key.(uint32) eventType := key.(uint32)
Publish(d, MyEvent3{ID: int(eventType)}) Publish(d, MyEvent3{ID: int(eventType)})
return true return true
}) })
// Wait for all events to be processed // Wait for all events to be processed
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
// Verify that we received at least the expected number of events // Verify that we received at least the expected number of events
// (there might be more if multiple goroutines subscribed to the same event type) // (there might be more if multiple goroutines subscribed to the same event type)
received := atomic.LoadInt64(&receivedCount) received := atomic.LoadInt64(&receivedCount)
assert.GreaterOrEqual(t, int(received), expectedTypes, assert.GreaterOrEqual(t, int(received), expectedTypes,
"Should have received at least %d events, got %d", expectedTypes, received) "Should have received at least %d events, got %d", expectedTypes, received)
// Verify that we have the expected number of unique event types // Verify that we have the expected number of unique event types
assert.Equal(t, numEventTypes, expectedTypes, assert.Equal(t, numEventTypes, expectedTypes,
"Should have exactly %d unique event types", numEventTypes) "Should have exactly %d unique event types", numEventTypes)
} }
func TestConcurrentHandlerRegistration(t *testing.T) { func TestConcurrentHandlerRegistration(t *testing.T) {
const numGoroutines = 100 const numGoroutines = 100
// Test concurrent subscriptions to the same event type // Test concurrent subscriptions to the same event type
t.Run("SameEventType", func(t *testing.T) { t.Run("SameEventType", func(t *testing.T) {
d := NewDispatcher() d := NewDispatcher()
var handlerCount int64 var handlerCount int64
var wg sync.WaitGroup var wg sync.WaitGroup
// Start multiple goroutines subscribing to the same event type (0x1) // Start multiple goroutines subscribing to the same event type (0x1)
for i := 0; i < numGoroutines; i++ { for i := 0; i < numGoroutines; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) { SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
atomic.AddInt64(&handlerCount, 1) atomic.AddInt64(&handlerCount, 1)
}) })
}() }()
} }
wg.Wait() wg.Wait()
// Verify all handlers were registered by publishing an event // Verify all handlers were registered by publishing an event
atomic.StoreInt64(&handlerCount, 0) atomic.StoreInt64(&handlerCount, 0)
Publish(d, MyEvent1{}) Publish(d, MyEvent1{})
// Small delay to ensure all handlers have executed // Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount), assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
"Not all handlers were registered due to race condition") "Not all handlers were registered due to race condition")
}) })
// Test concurrent subscriptions to different event types // Test concurrent subscriptions to different event types
t.Run("DifferentEventTypes", func(t *testing.T) { t.Run("DifferentEventTypes", func(t *testing.T) {
d := NewDispatcher() d := NewDispatcher()
var wg sync.WaitGroup var wg sync.WaitGroup
receivedEvents := make(map[uint32]*int64) receivedEvents := make(map[uint32]*int64)
// Create multiple event types and subscribe concurrently // Create multiple event types and subscribe concurrently
for i := 0; i < numGoroutines; i++ { for i := 0; i < numGoroutines; i++ {
eventType := uint32(100 + i) eventType := uint32(100 + i)
counter := new(int64) counter := new(int64)
receivedEvents[eventType] = counter receivedEvents[eventType] = counter
wg.Add(1) wg.Add(1)
go func(et uint32, cnt *int64) { go func(et uint32, cnt *int64) {
defer wg.Done() defer wg.Done()
SubscribeTo(d, et, func(ev MyEvent3) { SubscribeTo(d, et, func(ev MyEvent3) {
atomic.AddInt64(cnt, 1) atomic.AddInt64(cnt, 1)
}) })
}(eventType, counter) }(eventType, counter)
} }
wg.Wait() wg.Wait()
// Publish events to all types // Publish events to all types
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ { for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
Publish(d, MyEvent3{ID: int(eventType)}) Publish(d, MyEvent3{ID: int(eventType)})
} }
// Small delay to ensure all handlers have executed // Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
// Verify all event types received their events // Verify all event types received their events
for eventType, counter := range receivedEvents { for eventType, counter := range receivedEvents {
assert.Equal(t, int64(1), atomic.LoadInt64(counter), assert.Equal(t, int64(1), atomic.LoadInt64(counter),
"Event type %d did not receive its event", eventType) "Event type %d did not receive its event", eventType)
} }
}) })
} }
func TestBackpressure(t *testing.T) { func TestBackpressure(t *testing.T) {
d := NewDispatcher() d := NewDispatcher()
d.maxQueue = 10 d.maxQueue = 10
var processedCount int64 var processedCount int64
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) { unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
atomic.AddInt64(&processedCount, 1) atomic.AddInt64(&processedCount, 1)
}) })
defer unsub() defer unsub()
const eventsToPublish = 1000 const eventsToPublish = 1000
for i := 0; i < eventsToPublish; i++ { for i := 0; i < eventsToPublish; i++ {
Publish(d, MyEvent3{ID: 0x200}) Publish(d, MyEvent3{ID: 0x200})
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
// Verify all events were eventually processed // Verify all events were eventually processed
finalProcessed := atomic.LoadInt64(&processedCount) finalProcessed := atomic.LoadInt64(&processedCount)
assert.Equal(t, int64(eventsToPublish), finalProcessed) assert.Equal(t, int64(eventsToPublish), finalProcessed)
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish) t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
} }
// ------------------------------------- Test Events ------------------------------------- // ------------------------------------- Test Events -------------------------------------
const ( const (
TypeEvent1 = 0x1 TypeEvent1 = 0x1
TypeEvent2 = 0x2 TypeEvent2 = 0x2
) )
type MyEvent1 struct { type MyEvent1 struct {
Number int Number int
} }
func (t MyEvent1) Type() uint32 { return TypeEvent1 } func (t MyEvent1) Type() uint32 { return TypeEvent1 }
type MyEvent2 struct { type MyEvent2 struct {
Text string Text string
} }
func (t MyEvent2) Type() uint32 { return TypeEvent2 } func (t MyEvent2) Type() uint32 { return TypeEvent2 }
type MyEvent3 struct { type MyEvent3 struct {
ID int ID int
} }
func (t MyEvent3) Type() uint32 { return uint32(t.ID) } func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
@@ -1,4 +1,4 @@
package proxy package logmon
import ( import (
"context" "context"
@@ -8,15 +8,25 @@ import (
"sync" "sync"
"time" "time"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/internal/event"
) )
const DataEventID = 0x04
type DataEvent struct {
Data []byte
}
func (e DataEvent) Type() uint32 {
return DataEventID
}
// circularBuffer is a fixed-size circular byte buffer that overwrites // circularBuffer is a fixed-size circular byte buffer that overwrites
// oldest data when full. It provides O(1) writes and O(n) reads. // oldest data when full. It provides O(1) writes and O(n) reads.
type circularBuffer struct { type circularBuffer struct {
data []byte // pre-allocated capacity data []byte
head int // next write position head int
size int // current number of bytes stored (0 to cap) size int
} }
func newCircularBuffer(capacity int) *circularBuffer { func newCircularBuffer(capacity int) *circularBuffer {
@@ -27,8 +37,6 @@ func newCircularBuffer(capacity int) *circularBuffer {
} }
} }
// Write appends bytes to the buffer, overwriting oldest data when full.
// Data is copied into the internal buffer (not stored by reference).
func (cb *circularBuffer) Write(p []byte) { func (cb *circularBuffer) Write(p []byte) {
if len(p) == 0 { if len(p) == 0 {
return return
@@ -36,7 +44,6 @@ func (cb *circularBuffer) Write(p []byte) {
cap := len(cb.data) cap := len(cb.data)
// If input is larger than capacity, only keep the last cap bytes
if len(p) >= cap { if len(p) >= cap {
copy(cb.data, p[len(p)-cap:]) copy(cb.data, p[len(p)-cap:])
cb.head = 0 cb.head = 0
@@ -44,28 +51,22 @@ func (cb *circularBuffer) Write(p []byte) {
return return
} }
// Calculate how much space is available from head to end of buffer
firstPart := cap - cb.head firstPart := cap - cb.head
if firstPart >= len(p) { if firstPart >= len(p) {
// All data fits without wrapping
copy(cb.data[cb.head:], p) copy(cb.data[cb.head:], p)
cb.head = (cb.head + len(p)) % cap cb.head = (cb.head + len(p)) % cap
} else { } else {
// Data wraps around
copy(cb.data[cb.head:], p[:firstPart]) copy(cb.data[cb.head:], p[:firstPart])
copy(cb.data[:len(p)-firstPart], p[firstPart:]) copy(cb.data[:len(p)-firstPart], p[firstPart:])
cb.head = len(p) - firstPart cb.head = len(p) - firstPart
} }
// Update size
cb.size += len(p) cb.size += len(p)
if cb.size > cap { if cb.size > cap {
cb.size = cap cb.size = cap
} }
} }
// GetHistory returns all buffered data in correct order (oldest to newest).
// Returns a new slice (copy), not a view into internal buffer.
func (cb *circularBuffer) GetHistory() []byte { func (cb *circularBuffer) GetHistory() []byte {
if cb.size == 0 { if cb.size == 0 {
return nil return nil
@@ -74,14 +75,11 @@ func (cb *circularBuffer) GetHistory() []byte {
result := make([]byte, cb.size) result := make([]byte, cb.size)
cap := len(cb.data) cap := len(cb.data)
// Calculate start position (oldest data)
start := (cb.head - cb.size + cap) % cap start := (cb.head - cb.size + cap) % cap
if start+cb.size <= cap { if start+cb.size <= cap {
// Data is contiguous, single copy
copy(result, cb.data[start:start+cb.size]) copy(result, cb.data[start:start+cb.size])
} else { } else {
// Data wraps around, two copies
firstPart := cap - start firstPart := cap - start
copy(result[:firstPart], cb.data[start:]) copy(result[:firstPart], cb.data[start:])
copy(result[firstPart:], cb.data[:cb.size-firstPart]) copy(result[firstPart:], cb.data[:cb.size-firstPart])
@@ -90,42 +88,38 @@ func (cb *circularBuffer) GetHistory() []byte {
return result return result
} }
type LogLevel int type Level int
const ( const (
LevelDebug LogLevel = iota LevelDebug Level = iota
LevelInfo LevelInfo
LevelWarn LevelWarn
LevelError LevelError
LogBufferSize = 100 * 1024 BufferSize = 100 * 1024
) )
type LogMonitor struct { type Monitor struct {
eventbus *event.Dispatcher eventbus *event.Dispatcher
mu sync.RWMutex mu sync.RWMutex
buffer *circularBuffer buffer *circularBuffer
bufferMu sync.RWMutex bufferMu sync.RWMutex
// typically this can be os.Stdout
stdout io.Writer stdout io.Writer
// logging levels level Level
level LogLevel prefix string
prefix string
// timestamps
timeFormat string timeFormat string
} }
func NewLogMonitor() *LogMonitor { func New() *Monitor {
return NewLogMonitorWriter(os.Stdout) return NewWriter(os.Stdout)
} }
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor { func NewWriter(stdout io.Writer) *Monitor {
return &LogMonitor{ return &Monitor{
eventbus: event.NewDispatcherConfig(1000), eventbus: event.NewDispatcherConfig(1000),
buffer: nil, // lazy initialized on first Write buffer: nil,
stdout: stdout, stdout: stdout,
level: LevelInfo, level: LevelInfo,
prefix: "", prefix: "",
@@ -133,7 +127,7 @@ func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
} }
} }
func (w *LogMonitor) Write(p []byte) (n int, err error) { func (w *Monitor) Write(p []byte) (n int, err error) {
if len(p) == 0 { if len(p) == 0 {
return 0, nil return 0, nil
} }
@@ -145,19 +139,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
w.bufferMu.Lock() w.bufferMu.Lock()
if w.buffer == nil { if w.buffer == nil {
w.buffer = newCircularBuffer(LogBufferSize) w.buffer = newCircularBuffer(BufferSize)
} }
w.buffer.Write(p) w.buffer.Write(p)
w.bufferMu.Unlock() w.bufferMu.Unlock()
// Make a copy for broadcast to preserve immutability
bufferCopy := make([]byte, len(p)) bufferCopy := make([]byte, len(p))
copy(bufferCopy, p) copy(bufferCopy, p)
w.broadcast(bufferCopy) w.broadcast(bufferCopy)
return n, nil return n, nil
} }
func (w *LogMonitor) GetHistory() []byte { func (w *Monitor) GetHistory() []byte {
w.bufferMu.RLock() w.bufferMu.RLock()
defer w.bufferMu.RUnlock() defer w.bufferMu.RUnlock()
if w.buffer == nil { if w.buffer == nil {
@@ -168,41 +161,41 @@ func (w *LogMonitor) GetHistory() []byte {
// Clear releases the buffer memory, making it eligible for GC. // Clear releases the buffer memory, making it eligible for GC.
// The buffer will be lazily re-allocated on the next Write. // The buffer will be lazily re-allocated on the next Write.
func (w *LogMonitor) Clear() { func (w *Monitor) Clear() {
w.bufferMu.Lock() w.bufferMu.Lock()
w.buffer = nil w.buffer = nil
w.bufferMu.Unlock() w.bufferMu.Unlock()
} }
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc { func (w *Monitor) OnLogData(callback func(data []byte)) context.CancelFunc {
return event.Subscribe(w.eventbus, func(e LogDataEvent) { return event.Subscribe(w.eventbus, func(e DataEvent) {
callback(e.Data) callback(e.Data)
}) })
} }
func (w *LogMonitor) broadcast(msg []byte) { func (w *Monitor) broadcast(msg []byte) {
event.Publish(w.eventbus, LogDataEvent{Data: msg}) event.Publish(w.eventbus, DataEvent{Data: msg})
} }
func (w *LogMonitor) SetPrefix(prefix string) { func (w *Monitor) SetPrefix(prefix string) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
w.prefix = prefix w.prefix = prefix
} }
func (w *LogMonitor) SetLogLevel(level LogLevel) { func (w *Monitor) SetLogLevel(level Level) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
w.level = level w.level = level
} }
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) { func (w *Monitor) SetLogTimeFormat(timeFormat string) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
w.timeFormat = timeFormat w.timeFormat = timeFormat
} }
func (w *LogMonitor) formatMessage(level string, msg string) []byte { func (w *Monitor) formatMessage(level string, msg string) []byte {
prefix := "" prefix := ""
if w.prefix != "" { if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix) prefix = fmt.Sprintf("[%s] ", w.prefix)
@@ -211,49 +204,38 @@ func (w *LogMonitor) formatMessage(level string, msg string) []byte {
if w.timeFormat != "" { if w.timeFormat != "" {
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat)) timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
} }
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg)) return fmt.Appendf(nil, "%s%s[%s] %s\n", timestamp, prefix, level, msg)
} }
func (w *LogMonitor) log(level LogLevel, msg string) { func (w *Monitor) log(level Level, msg string) {
if level < w.level { if level < w.level {
return return
} }
w.Write(w.formatMessage(level.String(), msg)) w.Write(w.formatMessage(level.String(), msg))
} }
func (w *LogMonitor) Debug(msg string) { func (w *Monitor) Debug(msg string) { w.log(LevelDebug, msg) }
w.log(LevelDebug, msg) func (w *Monitor) Info(msg string) { w.log(LevelInfo, msg) }
} func (w *Monitor) Warn(msg string) { w.log(LevelWarn, msg) }
func (w *Monitor) Error(msg string) { w.log(LevelError, msg) }
func (w *LogMonitor) Info(msg string) { func (w *Monitor) Debugf(format string, args ...any) {
w.log(LevelInfo, msg)
}
func (w *LogMonitor) Warn(msg string) {
w.log(LevelWarn, msg)
}
func (w *LogMonitor) Error(msg string) {
w.log(LevelError, msg)
}
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
w.log(LevelDebug, fmt.Sprintf(format, args...)) w.log(LevelDebug, fmt.Sprintf(format, args...))
} }
func (w *LogMonitor) Infof(format string, args ...interface{}) { func (w *Monitor) Infof(format string, args ...any) {
w.log(LevelInfo, fmt.Sprintf(format, args...)) w.log(LevelInfo, fmt.Sprintf(format, args...))
} }
func (w *LogMonitor) Warnf(format string, args ...interface{}) { func (w *Monitor) Warnf(format string, args ...any) {
w.log(LevelWarn, fmt.Sprintf(format, args...)) w.log(LevelWarn, fmt.Sprintf(format, args...))
} }
func (w *LogMonitor) Errorf(format string, args ...interface{}) { func (w *Monitor) Errorf(format string, args ...any) {
w.log(LevelError, fmt.Sprintf(format, args...)) w.log(LevelError, fmt.Sprintf(format, args...))
} }
func (l LogLevel) String() string { func (l Level) String() string {
switch l { switch l {
case LevelDebug: case LevelDebug:
return "DEBUG" return "DEBUG"
@@ -1,4 +1,4 @@
package proxy package logmon
import ( import (
"bytes" "bytes"
@@ -10,9 +10,8 @@ import (
) )
func TestLogMonitor(t *testing.T) { func TestLogMonitor(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard) logMonitor := NewWriter(io.Discard)
// A WaitGroup is used to wait for all the expected writes to complete
var wg sync.WaitGroup var wg sync.WaitGroup
client1Messages := make([]byte, 0) client1Messages := make([]byte, 0)
@@ -34,10 +33,8 @@ func TestLogMonitor(t *testing.T) {
logMonitor.Write([]byte("2")) logMonitor.Write([]byte("2"))
logMonitor.Write([]byte("3")) logMonitor.Write([]byte("3"))
// wait for all writes to complete
wg.Wait() wg.Wait()
// Check the buffer
expectedHistory := "123" expectedHistory := "123"
history := string(logMonitor.GetHistory()) history := string(logMonitor.GetHistory())
@@ -57,14 +54,11 @@ func TestLogMonitor(t *testing.T) {
} }
func TestWrite_ImmutableBuffer(t *testing.T) { func TestWrite_ImmutableBuffer(t *testing.T) {
// Create a new LogMonitor instance lm := NewWriter(io.Discard)
lm := NewLogMonitorWriter(io.Discard)
// Prepare a message to write
msg := []byte("Hello, World!") msg := []byte("Hello, World!")
lenmsg := len(msg) lenmsg := len(msg)
// Write the message to the LogMonitor
n, err := lm.Write(msg) n, err := lm.Write(msg)
if err != nil { if err != nil {
t.Fatalf("Write failed: %v", err) t.Fatalf("Write failed: %v", err)
@@ -74,13 +68,10 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
t.Errorf("Expected %d bytes written but got %d", lenmsg, n) t.Errorf("Expected %d bytes written but got %d", lenmsg, n)
} }
// Change the original message msg[0] = 'B'
msg[0] = 'B' // This should not affect the buffer
// Get the history from the LogMonitor
history := lm.GetHistory() history := lm.GetHistory()
// Check that the history contains the original message, not the modified one
expected := []byte("Hello, World!") expected := []byte("Hello, World!")
if !bytes.Equal(history, expected) { if !bytes.Equal(history, expected) {
t.Errorf("Expected history to be %q, got %q", expected, history) t.Errorf("Expected history to be %q, got %q", expected, history)
@@ -88,16 +79,12 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
} }
func TestWrite_LogTimeFormat(t *testing.T) { func TestWrite_LogTimeFormat(t *testing.T) {
// Create a new LogMonitor instance lm := NewWriter(io.Discard)
lm := NewLogMonitorWriter(io.Discard)
// Enable timestamps
lm.timeFormat = time.RFC3339 lm.timeFormat = time.RFC3339
// Write the message to the LogMonitor
lm.Info("Hello, World!") lm.Info("Hello, World!")
// Get the history from the LogMonitor
history := lm.GetHistory() history := lm.GetHistory()
timestamp := "" timestamp := ""
@@ -115,48 +102,40 @@ func TestWrite_LogTimeFormat(t *testing.T) {
} }
func TestCircularBuffer_WrapAround(t *testing.T) { func TestCircularBuffer_WrapAround(t *testing.T) {
// Create a small buffer to test wrap-around
cb := newCircularBuffer(10) cb := newCircularBuffer(10)
// Write "hello" (5 bytes)
cb.Write([]byte("hello")) cb.Write([]byte("hello"))
if got := string(cb.GetHistory()); got != "hello" { if got := string(cb.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got) t.Errorf("Expected 'hello', got %q", got)
} }
// Write "world" (5 bytes) - buffer now full
cb.Write([]byte("world")) cb.Write([]byte("world"))
if got := string(cb.GetHistory()); got != "helloworld" { if got := string(cb.GetHistory()); got != "helloworld" {
t.Errorf("Expected 'helloworld', got %q", got) t.Errorf("Expected 'helloworld', got %q", got)
} }
// Write "12345" (5 bytes) - should overwrite "hello"
cb.Write([]byte("12345")) cb.Write([]byte("12345"))
if got := string(cb.GetHistory()); got != "world12345" { if got := string(cb.GetHistory()); got != "world12345" {
t.Errorf("Expected 'world12345', got %q", got) t.Errorf("Expected 'world12345', got %q", got)
} }
// Write data larger than buffer capacity cb.Write([]byte("abcdefghijklmnop"))
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
if got := string(cb.GetHistory()); got != "ghijklmnop" { if got := string(cb.GetHistory()); got != "ghijklmnop" {
t.Errorf("Expected 'ghijklmnop', got %q", got) t.Errorf("Expected 'ghijklmnop', got %q", got)
} }
} }
func TestCircularBuffer_BoundaryConditions(t *testing.T) { func TestCircularBuffer_BoundaryConditions(t *testing.T) {
// Test empty buffer
cb := newCircularBuffer(10) cb := newCircularBuffer(10)
if got := cb.GetHistory(); got != nil { if got := cb.GetHistory(); got != nil {
t.Errorf("Expected nil for empty buffer, got %q", got) t.Errorf("Expected nil for empty buffer, got %q", got)
} }
// Test exact capacity
cb.Write([]byte("1234567890")) cb.Write([]byte("1234567890"))
if got := string(cb.GetHistory()); got != "1234567890" { if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got) t.Errorf("Expected '1234567890', got %q", got)
} }
// Test write exactly at capacity boundary
cb = newCircularBuffer(10) cb = newCircularBuffer(10)
cb.Write([]byte("12345")) cb.Write([]byte("12345"))
cb.Write([]byte("67890")) cb.Write([]byte("67890"))
@@ -166,19 +145,16 @@ func TestCircularBuffer_BoundaryConditions(t *testing.T) {
} }
func TestLogMonitor_LazyInit(t *testing.T) { func TestLogMonitor_LazyInit(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard) lm := NewWriter(io.Discard)
// Buffer should be nil before any writes
if lm.buffer != nil { if lm.buffer != nil {
t.Error("Expected buffer to be nil before first write") t.Error("Expected buffer to be nil before first write")
} }
// GetHistory should return nil when buffer is nil
if got := lm.GetHistory(); got != nil { if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history before first write, got %q", got) t.Errorf("Expected nil history before first write, got %q", got)
} }
// Write should lazily initialize the buffer
lm.Write([]byte("test")) lm.Write([]byte("test"))
if lm.buffer == nil { if lm.buffer == nil {
@@ -191,15 +167,13 @@ func TestLogMonitor_LazyInit(t *testing.T) {
} }
func TestLogMonitor_Clear(t *testing.T) { func TestLogMonitor_Clear(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard) lm := NewWriter(io.Discard)
// Write some data
lm.Write([]byte("hello")) lm.Write([]byte("hello"))
if got := string(lm.GetHistory()); got != "hello" { if got := string(lm.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got) t.Errorf("Expected 'hello', got %q", got)
} }
// Clear should release the buffer
lm.Clear() lm.Clear()
if lm.buffer != nil { if lm.buffer != nil {
@@ -212,9 +186,8 @@ func TestLogMonitor_Clear(t *testing.T) {
} }
func TestLogMonitor_ClearAndReuse(t *testing.T) { func TestLogMonitor_ClearAndReuse(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard) lm := NewWriter(io.Discard)
// Write, clear, then write again
lm.Write([]byte("first")) lm.Write([]byte("first"))
lm.Clear() lm.Clear()
lm.Write([]byte("second")) lm.Write([]byte("second"))
@@ -225,13 +198,12 @@ func TestLogMonitor_ClearAndReuse(t *testing.T) {
} }
func BenchmarkLogMonitorWrite(b *testing.B) { func BenchmarkLogMonitorWrite(b *testing.B) {
// Test data of varying sizes
smallMsg := []byte("small message\n") smallMsg := []byte("small message\n")
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n") mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n") largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
b.Run("SmallWrite", func(b *testing.B) { b.Run("SmallWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard) lm := NewWriter(io.Discard)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
lm.Write(smallMsg) lm.Write(smallMsg)
@@ -239,7 +211,7 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
}) })
b.Run("MediumWrite", func(b *testing.B) { b.Run("MediumWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard) lm := NewWriter(io.Discard)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
lm.Write(mediumMsg) lm.Write(mediumMsg)
@@ -247,7 +219,7 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
}) })
b.Run("LargeWrite", func(b *testing.B) { b.Run("LargeWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard) lm := NewWriter(io.Discard)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
lm.Write(largeMsg) lm.Write(largeMsg)
@@ -255,8 +227,7 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
}) })
b.Run("WithSubscribers", func(b *testing.B) { b.Run("WithSubscribers", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard) lm := NewWriter(io.Discard)
// Add some subscribers
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
lm.OnLogData(func(data []byte) {}) lm.OnLogData(func(data []byte) {})
} }
@@ -267,8 +238,7 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
}) })
b.Run("GetHistory", func(b *testing.B) { b.Run("GetHistory", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard) lm := NewWriter(io.Discard)
// Pre-populate with data
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
lm.Write(mediumMsg) lm.Write(mediumMsg)
} }
@@ -278,39 +248,3 @@ func BenchmarkLogMonitorWrite(b *testing.B) {
} }
}) })
} }
/*
Benchmark Results - MBP M1 Pro
Before (ring.Ring):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 43 ns | 40 B | 2 |
| MediumWrite (241B) | 76 ns | 264 B | 2 |
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
After (circularBuffer 10KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 67 ns | 240 B | 1 |
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
After (circularBuffer 100KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|-----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 66 ns | 240 B | 1 |
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
Summary:
- GetHistory: 139x faster (10KB), 18x faster (100KB)
- Allocations: reduced from 2 to 1 across all operations
- Small/medium writes: ~1.1-1.6x faster
*/
+92
View File
@@ -0,0 +1,92 @@
package perf
type LUID struct {
LowPart uint32
HighPart int32
}
const maxEnumAdapters = 16
type D3DKMT_ENUMADAPTERS2 struct {
NumAdapters uint32
pAdapters uintptr
}
type D3DKMT_ADAPTERINFO struct {
hAdapter uint32
AdapterLuid LUID
NumOfSources uint32
bPresentMoveRegionsPreferred int32
}
type D3DKMT_OPENADAPTERFROMLUID struct {
AdapterLuid LUID
hAdapter uint32
}
type D3DKMT_CLOSEADAPTER struct {
hAdapter uint32
}
type KMTQUERYADAPTERINFOTYPE int32
const (
KMTQAITYPE_UMDRIVERPRIVATE KMTQUERYADAPTERINFOTYPE = 0
KMTQAITYPE_ADAPTERREGISTRYINFO KMTQUERYADAPTERINFOTYPE = 8
KMTQAITYPE_DRIVERVERSION KMTQUERYADAPTERINFOTYPE = 13
KMTQAITYPE_PHYSICALADAPTERDEVICEIDS KMTQUERYADAPTERINFOTYPE = 31
KMTQAITYPE_NODEPERFDATA KMTQUERYADAPTERINFOTYPE = 61
KMTQAITYPE_ADAPTERPERFDATA KMTQUERYADAPTERINFOTYPE = 62
KMTQAITYPE_ADAPTERPERFDATA_CAPS KMTQUERYADAPTERINFOTYPE = 63
)
type D3DKMT_QUERYADAPTERINFO struct {
hAdapter uint32
Type KMTQUERYADAPTERINFOTYPE
pPrivateDriverData uintptr
PrivateDriverDataSize uint32
}
type D3DKMT_ADAPTER_PERFDATA struct {
PhysicalAdapterIndex uint32
MemoryFrequency uint64
MaxMemoryFrequency uint64
MaxMemoryFrequencyOC uint64
MemoryBandwidth uint64
PCIEBandwidth uint64
FanRPM uint32
Power uint32
Temperature uint32
PowerStateOverride byte
}
type D3DKMT_QUERYSTATISTICS_TYPE int32
const (
D3DKMT_QUERYSTATISTICS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 0
D3DKMT_QUERYSTATISTICS_PROCESS D3DKMT_QUERYSTATISTICS_TYPE = 1
D3DKMT_QUERYSTATISTICS_PROCESS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 2
D3DKMT_QUERYSTATISTICS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 3
D3DKMT_QUERYSTATISTICS_PROCESS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 4
D3DKMT_QUERYSTATISTICS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 5
D3DKMT_QUERYSTATISTICS_PROCESS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 6
D3DKMT_QUERYSTATISTICS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 7
D3DKMT_QUERYSTATISTICS_PROCESS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 8
)
type D3DKMT_ADAPTER_PERFDATACAPS struct {
PhysicalAdapterIndex uint32
MaxMemoryBandwidth uint64
MaxPCIEBandwidth uint64
MaxFanRPM uint32
TemperatureMax uint32
TemperatureWarning uint32
}
type D3DKMT_QUERYSTATISTICS_QUERY_SEGMENT struct {
SegmentId uint32
}
type D3DKMT_QUERYSTATISTICS_QUERY_NODE struct {
NodeId uint32
}
+529
View File
@@ -0,0 +1,529 @@
//go:build windows
package perf
import (
"context"
"encoding/binary"
"fmt"
"sync"
"time"
"unsafe"
"github.com/mostlygeek/llama-swap/internal/logmon"
"golang.org/x/sys/windows"
)
var (
d3dkmDLL *windows.LazyDLL
procEnumAdapters2 *windows.LazyProc
procOpenAdapterFromLuid *windows.LazyProc
procCloseAdapter *windows.LazyProc
procQueryAdapterInfo *windows.LazyProc
procQueryStatistics *windows.LazyProc
d3dkmtInitOnce sync.Once
d3dkmtInitErr error
)
// initD3DKMT lazily loads gdi32.dll and resolves D3DKMT function pointers.
// Safe for concurrent use via sync.Once.
func initD3DKMT() error {
d3dkmtInitOnce.Do(func() {
d3dkmDLL = windows.NewLazySystemDLL("gdi32.dll")
procEnumAdapters2 = d3dkmDLL.NewProc("D3DKMTEnumAdapters2")
procOpenAdapterFromLuid = d3dkmDLL.NewProc("D3DKMTOpenAdapterFromLuid")
procCloseAdapter = d3dkmDLL.NewProc("D3DKMTCloseAdapter")
procQueryAdapterInfo = d3dkmDLL.NewProc("D3DKMTQueryAdapterInfo")
procQueryStatistics = d3dkmDLL.NewProc("D3DKMTQueryStatistics")
for name, p := range map[string]*windows.LazyProc{
"D3DKMTEnumAdapters2": procEnumAdapters2,
"D3DKMTOpenAdapterFromLuid": procOpenAdapterFromLuid,
"D3DKMTCloseAdapter": procCloseAdapter,
"D3DKMTQueryAdapterInfo": procQueryAdapterInfo,
"D3DKMTQueryStatistics": procQueryStatistics,
} {
if err := p.Find(); err != nil {
d3dkmtInitErr = fmt.Errorf("D3DKMT %s not found: %w", name, err)
return
}
}
})
return d3dkmtInitErr
}
// ntstatusCall invokes a D3DKMT function and returns a non-nil error if the
// NTSTATUS result is not STATUS_SUCCESS (0).
func ntstatusCall(proc *windows.LazyProc, arg unsafe.Pointer) error {
ret, _, _ := proc.Call(uintptr(arg))
if ret != 0 {
return fmt.Errorf("NTSTATUS 0x%08x", uint32(ret))
}
return nil
}
// d3dkmEnumerateAdapters enumerates all available graphics adapters via
// D3DKMTEnumAdapters2.
func d3dkmEnumerateAdapters() ([]D3DKMT_ADAPTERINFO, error) {
var adapters [maxEnumAdapters]D3DKMT_ADAPTERINFO
enum := D3DKMT_ENUMADAPTERS2{
NumAdapters: maxEnumAdapters,
pAdapters: uintptr(unsafe.Pointer(&adapters[0])),
}
if err := ntstatusCall(procEnumAdapters2, unsafe.Pointer(&enum)); err != nil {
return nil, fmt.Errorf("EnumAdapters2: %w", err)
}
if enum.NumAdapters == 0 {
return nil, fmt.Errorf("no adapters found")
}
result := make([]D3DKMT_ADAPTERINFO, enum.NumAdapters)
for i := uint32(0); i < enum.NumAdapters; i++ {
result[i] = adapters[i]
}
return result, nil
}
// d3dkmOpenAdapter opens a D3DKMT adapter handle for the given LUID.
func d3dkmOpenAdapter(luid LUID) (uint32, error) {
req := D3DKMT_OPENADAPTERFROMLUID{
AdapterLuid: luid,
}
if err := ntstatusCall(procOpenAdapterFromLuid, unsafe.Pointer(&req)); err != nil {
return 0, fmt.Errorf("OpenAdapterFromLuid: %w", err)
}
return req.hAdapter, nil
}
// d3dkmCloseAdapter closes a previously opened D3DKMT adapter handle.
func d3dkmCloseAdapter(hAdapter uint32) error {
req := D3DKMT_CLOSEADAPTER{hAdapter: hAdapter}
return ntstatusCall(procCloseAdapter, unsafe.Pointer(&req))
}
// d3dkmGetAdapterPerfData queries per-adapter performance data (temperature,
// fan RPM, power, bandwidth) via KMTQAITYPE_ADAPTERPERFDATA.
func d3dkmGetAdapterPerfData(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATA, error) {
var data D3DKMT_ADAPTER_PERFDATA
req := D3DKMT_QUERYADAPTERINFO{
hAdapter: hAdapter,
Type: KMTQAITYPE_ADAPTERPERFDATA,
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
}
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATA): %w", err)
}
return &data, nil
}
// d3dkmGetAdapterPerfDataCaps queries static adapter performance capabilities
// (max fan RPM, temperature limits, max bandwidth) via KMTQAITYPE_ADAPTERPERFDATA_CAPS.
func d3dkmGetAdapterPerfDataCaps(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATACAPS, error) {
var data D3DKMT_ADAPTER_PERFDATACAPS
req := D3DKMT_QUERYADAPTERINFO{
hAdapter: hAdapter,
Type: KMTQAITYPE_ADAPTERPERFDATA_CAPS,
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
}
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATACAPS): %w", err)
}
return &data, nil
}
type queryStatsBuffer struct {
Type int32 // offset 0
AdapterLuid LUID // offset 4
hProcess uintptr // offset 16
// _result mirrors the D3DKMT_QUERYSTATISTICS_RESULT union.
// sizeof(D3DKMT_QUERYSTATISTICS) == 0x328 (808 bytes) on x64.
//
// The C struct layout (x64):
// offset 0: Type (int32, 4 bytes)
// offset 4: AdapterLuid (LUID, 8 bytes)
// offset 12: 4 bytes padding (for 8-byte alignment of hProcess)
// offset 16: hProcess (HANDLE, 8 bytes)
// offset 24: QueryResult (union, 780 bytes — largest member is AdapterInformation)
// offset 804: anonymous input union (QueryNode.NodeId / QuerySegment.SegmentId, 4 bytes)
//
// Previous bug: _result was [776]byte, placing QueryId at offset 800 instead of 804.
// The kernel read NodeId/SegmentId from offset 804 (always zero from _pad),
// causing all NODE and SEGMENT queries to use index 0 regardless of the value
// passed in QueryId. This produced alternating behavior where only GPU util OR
// memory util appeared to work, depending on which test variant happened to put
// non-zero data near offset 804 in the result buffer.
_result [780]byte // offset 24, size 780 — places QueryId at offset 804
QueryId int32 // offset 804 — matches C anonymous union for NodeId/SegmentId
}
func init() {
var buf queryStatsBuffer
if unsafe.Sizeof(buf) != 808 {
panic(fmt.Sprintf("queryStatsBuffer size %d != expected 808 (sizeof D3DKMT_QUERYSTATISTICS on x64)", unsafe.Sizeof(buf)))
}
if unsafe.Offsetof(buf.QueryId) != 804 {
panic(fmt.Sprintf("queryStatsBuffer.QueryId offset %d != expected 804 (C anonymous union offset)", unsafe.Offsetof(buf.QueryId)))
}
var perfData D3DKMT_ADAPTER_PERFDATA
if unsafe.Sizeof(perfData) != 64 {
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATA size %d != expected 64 on x64", unsafe.Sizeof(perfData)))
}
var caps D3DKMT_ADAPTER_PERFDATACAPS
if unsafe.Sizeof(caps) != 40 {
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATACAPS size %d != expected 40 on x64", unsafe.Sizeof(caps)))
}
}
const (
qsoffsetNbSegments = 0
qsoffsetNodeCount = 4
qsoffsetCommitLimit = 0
qsoffsetBytesCommitted = 8
qsoffsetBytesResident = 16
qsoffsetRunningTime = 0
qsoffsetSystemRunningTime = 272
)
// d3dkmQueryAdapterStats returns the number of memory segments and compute
// nodes for the adapter identified by luid.
func d3dkmQueryAdapterStats(luid LUID) (nbSegments uint32, nodeCount uint32, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_ADAPTER),
AdapterLuid: luid,
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(ADAPTER): %w", err)
}
nbSegments = binary.LittleEndian.Uint32(buf._result[qsoffsetNbSegments : qsoffsetNbSegments+4])
nodeCount = binary.LittleEndian.Uint32(buf._result[qsoffsetNodeCount : qsoffsetNodeCount+4])
return nbSegments, nodeCount, nil
}
// d3dkmQuerySegmentStats returns the commit limit (total) and resident
// (used) bytes for the given memory segment of an adapter.
func d3dkmQuerySegmentStats(luid LUID, segmentID uint32) (commitLimit uint64, bytesResident uint64, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_SEGMENT),
AdapterLuid: luid,
QueryId: int32(segmentID),
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(SEGMENT %d): %w", segmentID, err)
}
commitLimit = binary.LittleEndian.Uint64(buf._result[qsoffsetCommitLimit : qsoffsetCommitLimit+8])
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesResident : qsoffsetBytesResident+8])
if bytesResident == 0 {
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesCommitted : qsoffsetBytesCommitted+8])
}
return commitLimit, bytesResident, nil
}
// d3dkmQueryNodeStats returns the global and system running time counters
// (in 100ns units) for the given compute node of an adapter.
func d3dkmQueryNodeStats(luid LUID, nodeID uint32) (runningTime uint64, systemRunningTime uint64, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_NODE),
AdapterLuid: luid,
QueryId: int32(nodeID),
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(NODE %d): %w", nodeID, err)
}
runningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetRunningTime : qsoffsetRunningTime+8])
systemRunningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetSystemRunningTime : qsoffsetSystemRunningTime+8])
return runningTime, systemRunningTime, nil
}
type nodeRunningTimes struct {
Global uint64
System uint64
}
// d3dkmtNodeUtil computes GPU node utilization as a percentage from running
// time deltas. Returns -1 if counters went backwards (wrap/reset), 0 if idle.
func d3dkmtNodeUtil(prevRT, curRT nodeRunningTimes, elapsed100ns int64) float64 {
if curRT.Global < prevRT.Global || curRT.System < prevRT.System {
return -1
}
gd := curRT.Global - prevRT.Global
sd := curRT.System - prevRT.System
if gd > 0 && sd > 0 {
util := float64(gd) / float64(sd)
if util > 1.0 {
util = 1.0
}
return util * 100.0
} else if gd > 0 && elapsed100ns > 0 {
util := float64(gd) / float64(elapsed100ns) * 100.0
if util > 100.0 {
util = 100.0
}
return util
}
return 0
}
// d3dkmtFanPct returns fan speed as a percentage of maxFanRPM, clamped to
// 100%. Returns 0 if maxFanRPM is unavailable or fan is not spinning.
func d3dkmtFanPct(fanRPM, maxFanRPM uint32) float64 {
if maxFanRPM > 0 && fanRPM > 0 {
pct := float64(fanRPM) / float64(maxFanRPM) * 100.0
if pct > 100.0 {
pct = 100.0
}
return pct
}
return 0
}
// d3dkmtPowerW converts power from deci-watts (as reported by D3DKMT) to
// watts. Returns 0 if the power value is zero.
func d3dkmtPowerW(power uint32) float64 {
if power > 0 {
return float64(power) / 10.0
}
return 0
}
// d3dkmtTempC converts temperature from deci-Celsius (as reported by D3DKMT)
// to degrees Celsius.
func d3dkmtTempC(tempDeciC uint32) int {
return int(tempDeciC / 10)
}
type d3dkmtAdapterState struct {
luid LUID
hAdapter uint32
nbSegments uint32
nodeCount uint32
maxFanRPM uint32
prevNodeRT map[uint32]nodeRunningTimes
prevTime time.Time
}
// tryD3DKMT attempts to start GPU monitoring using D3DKMT and optional PDH
// counters. It returns a channel of GpuStat snapshots or an error if no
// usable adapters are found.
func tryD3DKMT(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if err := initD3DKMT(); err != nil {
return nil, err
}
adapterInfos, err := d3dkmEnumerateAdapters()
if err != nil {
return nil, err
}
type adapterMeta struct {
luid LUID
nbSegments uint32
nodeCount uint32
maxFanRPM uint32
}
var metaList []adapterMeta
for i, ai := range adapterInfos {
hAdapter, err := d3dkmOpenAdapter(ai.AdapterLuid)
if err != nil {
logger.Debugf("adapter %d: open failed: %s", i, err.Error())
continue
}
nbSegments, nodeCount, err := d3dkmQueryAdapterStats(ai.AdapterLuid)
if err != nil {
logger.Debugf("adapter %d: query stats failed: %s", i, err.Error())
d3dkmCloseAdapter(hAdapter)
continue
}
caps, err := d3dkmGetAdapterPerfDataCaps(hAdapter)
if err != nil {
logger.Debugf("adapter %d: perf caps failed: %s", i, err.Error())
}
d3dkmCloseAdapter(hAdapter)
var maxFanRPM uint32
if caps != nil {
maxFanRPM = caps.MaxFanRPM
}
metaList = append(metaList, adapterMeta{
luid: ai.AdapterLuid,
nbSegments: nbSegments,
nodeCount: nodeCount,
maxFanRPM: maxFanRPM,
})
logger.Debugf("adapter %d: segments=%d nodes=%d fan_max=%d luid=%d:%d", i, nbSegments, nodeCount, maxFanRPM, ai.AdapterLuid.HighPart, ai.AdapterLuid.LowPart)
}
if len(metaList) == 0 {
return nil, fmt.Errorf("no usable D3DKMT adapters found")
}
pdhUtil, pdhErr := initPdhGpuUtil()
if pdhErr != nil {
logger.Debugf("PDH GPU utilization not available: %s", pdhErr.Error())
} else {
logger.Info("using PDH performance counters for GPU utilization")
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
if pdhUtil != nil {
defer pdhUtil.close()
}
var adapters []d3dkmtAdapterState
for _, m := range metaList {
hAdapter, err := d3dkmOpenAdapter(m.luid)
if err != nil {
logger.Debugf("reopen adapter failed: %s", err.Error())
continue
}
adapters = append(adapters, d3dkmtAdapterState{
luid: m.luid,
hAdapter: hAdapter,
nbSegments: m.nbSegments,
nodeCount: m.nodeCount,
maxFanRPM: m.maxFanRPM,
prevNodeRT: make(map[uint32]nodeRunningTimes),
})
}
if len(adapters) == 0 {
return
}
defer func() {
for _, a := range adapters {
d3dkmCloseAdapter(a.hAdapter)
}
}()
for i := range adapters {
a := &adapters[i]
for node := uint32(0); node < a.nodeCount; node++ {
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
if err != nil {
continue
}
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
}
a.prevTime = time.Now()
}
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
stats := make([]GpuStat, 0, len(adapters))
now := time.Now()
var pdhUtilMap map[LUID]float64
if pdhUtil != nil {
pdhUtilMap = pdhUtil.collect()
}
for i := range adapters {
a := &adapters[i]
perfData, err := d3dkmGetAdapterPerfData(a.hAdapter)
if err != nil {
logger.Debugf("adapter %d perfdata: %s", i, err.Error())
continue
}
var memUsedMB, memTotalMB int
for seg := uint32(0); seg < a.nbSegments; seg++ {
limit, resident, err := d3dkmQuerySegmentStats(a.luid, seg)
if err != nil {
continue
}
memUsedMB += int(resident / (1024 * 1024))
memTotalMB += int(limit / (1024 * 1024))
}
var gpuUtil float64
pdhGaveValue := false
if pdhUtilMap != nil {
if util, ok := pdhUtilMap[a.luid]; ok {
gpuUtil = util
pdhGaveValue = true
}
}
if !pdhGaveValue && a.nodeCount > 0 {
elapsedNs := now.Sub(a.prevTime).Nanoseconds()
elapsed100ns := elapsedNs / 100
for node := uint32(0); node < a.nodeCount; node++ {
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
if err != nil {
continue
}
if prevRT, ok := a.prevNodeRT[node]; ok {
if globalRT < prevRT.Global || systemRT < prevRT.System {
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
continue
}
nodeUtil := d3dkmtNodeUtil(prevRT, nodeRunningTimes{Global: globalRT, System: systemRT}, elapsed100ns)
if nodeUtil > gpuUtil {
gpuUtil = nodeUtil
}
}
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
}
a.prevTime = now
}
tempC := d3dkmtTempC(perfData.Temperature)
fanSpeedPct := d3dkmtFanPct(perfData.FanRPM, a.maxFanRPM)
powerDrawW := d3dkmtPowerW(perfData.Power)
var memUtilPct float64
if memTotalMB > 0 {
memUtilPct = float64(memUsedMB) / float64(memTotalMB) * 100.0
}
stats = append(stats, GpuStat{
Timestamp: now,
ID: i,
Name: fmt.Sprintf("GPU %d", i),
TempC: tempC,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtilPct,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
FanSpeedPct: fanSpeedPct,
PowerDrawW: powerDrawW,
})
}
if len(stats) > 0 {
select {
case ch <- stats:
default:
}
}
}
}
}()
return ch, nil
}
+98
View File
@@ -0,0 +1,98 @@
//go:build windows
package perf
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestD3dkmtNodeUtil_FullLoad(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 5000, System: 14000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 100.0, got)
}
func TestD3dkmtNodeUtil_PartialUtil(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 3000, System: 14000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 50.0, got)
}
func TestD3dkmtNodeUtil_Identical(t *testing.T) {
prev := nodeRunningTimes{Global: 10000, System: 10000}
cur := nodeRunningTimes{Global: 20000, System: 20000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 100.0, got)
}
func TestD3dkmtNodeUtil_CounterWrap(t *testing.T) {
prev := nodeRunningTimes{Global: 9000, System: 10000}
cur := nodeRunningTimes{Global: 1000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, -1.0, got)
}
func TestD3dkmtNodeUtil_SystemWrap(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 9000}
cur := nodeRunningTimes{Global: 5000, System: 1000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, -1.0, got)
}
func TestD3dkmtNodeUtil_ZeroDelta(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 1000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 0.0, got)
}
func TestD3dkmtNodeUtil_ElapsedFallback(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 6000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 50000)
assert.InDelta(t, 10.0, got, 0.01)
}
func TestD3dkmtFanPct_Normal(t *testing.T) {
assert.Equal(t, 50.0, d3dkmtFanPct(1500, 3000))
}
func TestD3dkmtFanPct_MaxFan(t *testing.T) {
assert.Equal(t, 100.0, d3dkmtFanPct(3000, 3000))
}
func TestD3dkmtFanPct_OverMaxClamped(t *testing.T) {
assert.Equal(t, 100.0, d3dkmtFanPct(4000, 3000))
}
func TestD3dkmtFanPct_ZeroMaxFan(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(1500, 0))
}
func TestD3dkmtFanPct_ZeroFanRPM(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(0, 3000))
}
func TestD3dkmtFanPct_BothZero(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(0, 0))
}
func TestD3dkmtPowerW(t *testing.T) {
assert.Equal(t, 250.0, d3dkmtPowerW(2500))
}
func TestD3dkmtPowerW_Zero(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtPowerW(0))
}
func TestD3dkmtTempC(t *testing.T) {
assert.Equal(t, 65, d3dkmtTempC(650))
}
func TestD3dkmtTempC_Zero(t *testing.T) {
assert.Equal(t, 0, d3dkmtTempC(0))
}
+214
View File
@@ -0,0 +1,214 @@
package perf
import (
"encoding/json"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"time"
)
// ParseNvidiaSmiLine parses a single line from nvidia-smi CSV output.
// Format: index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw
func ParseNvidiaSmiLine(line string) *GpuStat {
fields := strings.Split(line, ",")
if len(fields) < 9 {
return nil
}
id, _ := strconv.Atoi(strings.TrimSpace(fields[0]))
name := strings.TrimSpace(fields[1])
uuid := strings.TrimSpace(fields[2])
tempC, _ := strconv.Atoi(strings.TrimSpace(fields[3]))
gpuUtil, _ := strconv.ParseFloat(strings.TrimSpace(fields[4]), 64)
memUsed, _ := strconv.Atoi(strings.TrimSpace(fields[5]))
memTotal, _ := strconv.Atoi(strings.TrimSpace(fields[6]))
fanSpeed, _ := strconv.ParseFloat(strings.TrimSpace(fields[7]), 64)
powerDraw, _ := strconv.ParseFloat(strings.TrimSpace(fields[8]), 64)
var memUtil float64
if memTotal > 0 {
memUtil = float64(memUsed) / float64(memTotal) * 100
}
return &GpuStat{
Timestamp: time.Now(),
ID: id,
Name: name,
UUID: uuid,
TempC: tempC,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtil,
MemUsedMB: memUsed,
MemTotalMB: memTotal,
FanSpeedPct: fanSpeed,
PowerDrawW: powerDraw,
}
}
// mactopOutput maps the subset of mactop's headless JSON output that is
// relevant to GpuStat. Note that mactop's memory object is whole-system memory,
// not GPU-attributed; the darwin monitor overlays ioreg's GPU-attributed
// unified memory (see overlayIoregMem) so both backends report consistent
// memory figures.
type mactopOutput struct {
SocMetrics struct {
GPUPower float64 `json:"gpu_power"`
GPUFreq int `json:"gpu_freq_mhz"`
GPUTemp float64 `json:"gpu_temp"`
} `json:"soc_metrics"`
Memory struct {
Total uint64 `json:"total"`
Used uint64 `json:"used"`
} `json:"memory"`
GPUUsage float64 `json:"gpu_usage"`
SystemInfo struct {
Name string `json:"name"`
GPUCoreCount int `json:"gpu_core_count"`
} `json:"system_info"`
Fans []struct {
RPM int `json:"rpm"`
MinRPM int `json:"min_rpm"`
MaxRPM int `json:"max_rpm"`
} `json:"fans"`
Temperatures []struct {
Group string `json:"group"`
Avg float64 `json:"avg_celsius"`
} `json:"temperatures"`
}
// ioreg output uses ` = ` (with spaces) for top-level device properties and
// `=` (no spaces) for values inside nested dictionaries such as
// PerformanceStatistics.
var (
reIoregModel = regexp.MustCompile(`"model"\s*=\s*"([^"]+)"`)
reIoregCoreCount = regexp.MustCompile(`"gpu-core-count"\s*=\s*(\d+)`)
reIoregUtil = regexp.MustCompile(`"Device Utilization %"=(\d+)`)
reIoregMemUsed = regexp.MustCompile(`"In use system memory"=(\d+)`)
)
// ParseIoregOutput parses `ioreg -r -c IOGPU -d 1 -f` output into a GpuStat for
// the Apple Silicon integrated GPU. This is a fallback for when mactop is not
// installed: utilization and used memory are available, but power, temperature,
// and fan speed are not exposed by ioreg. memTotalMB is the unified memory size
// supplied by the caller, since Apple Silicon shares memory between CPU and GPU.
// Returns nil if no GPU device is found in the output.
func ParseIoregOutput(out []byte, memTotalMB int) *GpuStat {
utilMatch := reIoregUtil.FindSubmatch(out)
memMatch := reIoregMemUsed.FindSubmatch(out)
if utilMatch == nil && memMatch == nil {
return nil
}
var gpuUtil float64
if utilMatch != nil {
gpuUtil, _ = strconv.ParseFloat(string(utilMatch[1]), 64)
}
const toMB = 1024 * 1024
var memUsedMB int
if memMatch != nil {
memUsedBytes, _ := strconv.ParseInt(string(memMatch[1]), 10, 64)
memUsedMB = int(memUsedBytes / toMB)
}
var memUtil float64
if memTotalMB > 0 {
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
}
name := "Apple GPU"
if m := reIoregModel.FindSubmatch(out); m != nil {
name = string(m[1])
}
if m := reIoregCoreCount.FindSubmatch(out); m != nil {
if cores, err := strconv.Atoi(string(m[1])); err == nil && cores > 0 {
name = fmt.Sprintf("%s (%d-core GPU)", name, cores)
}
}
return &GpuStat{
Timestamp: time.Now(),
ID: 0,
Name: name,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtil,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
}
}
// ParseMactopLine parses a single line of mactop headless JSON output into a
// GpuStat for the Apple Silicon integrated GPU. Returns nil if the line cannot
// be parsed.
func ParseMactopLine(line string) *GpuStat {
line = strings.TrimSpace(line)
if line == "" {
return nil
}
var out mactopOutput
if err := json.Unmarshal([]byte(line), &out); err != nil {
return nil
}
const toMB = 1024 * 1024
memUsedMB := int(out.Memory.Used / toMB)
memTotalMB := int(out.Memory.Total / toMB)
var memUtil float64
if memTotalMB > 0 {
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
}
name := out.SystemInfo.Name
if name == "" {
name = "Apple GPU"
}
if out.SystemInfo.GPUCoreCount > 0 {
name = fmt.Sprintf("%s (%d-core GPU)", name, out.SystemInfo.GPUCoreCount)
}
// Unified memory has no dedicated VRAM sensor; use the memory temperature
// group when mactop exposes it.
var vramTempC int
for _, t := range out.Temperatures {
if strings.EqualFold(t.Group, "Memory") {
vramTempC = int(math.Round(t.Avg))
break
}
}
// Average fan load across all fans as a percentage of their RPM range.
var fanSpeed float64
var fanCount int
for _, f := range out.Fans {
if f.MaxRPM > f.MinRPM {
pct := float64(f.RPM-f.MinRPM) / float64(f.MaxRPM-f.MinRPM) * 100
if pct < 0 {
pct = 0
}
fanSpeed += pct
fanCount++
}
}
if fanCount > 0 {
fanSpeed /= float64(fanCount)
}
return &GpuStat{
Timestamp: time.Now(),
ID: 0,
Name: name,
TempC: int(math.Round(out.SocMetrics.GPUTemp)),
VramTempC: vramTempC,
GpuUtilPct: out.GPUUsage,
MemUtilPct: memUtil,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
FanSpeedPct: fanSpeed,
PowerDrawW: out.SocMetrics.GPUPower,
}
}
+206
View File
@@ -0,0 +1,206 @@
package perf
import (
"context"
"errors"
"sync"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/ring"
)
var (
ErrNotImplemented = errors.New("not implemented")
ErrNoGpuTool = errors.New("no GPU monitoring tool available")
)
type Monitor struct {
mutex sync.RWMutex
log *logmon.Monitor
conf config.PerformanceConfig
sysRing ring.Buffer[SysStat]
gpuRing ring.Buffer[[]GpuStat]
stopCtx context.Context
stopCancel context.CancelFunc
sysListeners map[chan SysStat]struct{}
gpuListeners map[chan []GpuStat]struct{}
}
func ringCapacity(c config.PerformanceConfig) int {
n := int(time.Hour / c.Every)
if n < 1 {
n = 1
}
return n
}
func New(c config.PerformanceConfig, logger *logmon.Monitor) (*Monitor, error) {
if c.Every < 100*time.Millisecond {
c.Every = 100 * time.Millisecond
}
if logger == nil {
return nil, errors.New("logger is required")
}
capacity := ringCapacity(c)
return &Monitor{
conf: c,
log: logger,
sysRing: ring.NewBuffer[SysStat](capacity),
gpuRing: ring.NewBuffer[[]GpuStat](capacity),
sysListeners: make(map[chan SysStat]struct{}),
gpuListeners: make(map[chan []GpuStat]struct{}),
}, nil
}
func (m *Monitor) Stop() {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.stopCancel == nil {
return
}
m.stopCancel()
m.stopCancel = nil
}
// UpdateConfig updates the monitor configuration and restarts if changed.
func (m *Monitor) UpdateConfig(newConf config.PerformanceConfig) {
m.mutex.RLock()
changed := m.conf != newConf
m.mutex.RUnlock()
if !changed {
return
}
m.Stop()
m.mutex.Lock()
m.conf = newConf
capacity := ringCapacity(newConf)
m.sysRing = ring.NewBuffer[SysStat](capacity)
m.gpuRing = ring.NewBuffer[[]GpuStat](capacity)
m.mutex.Unlock()
if !newConf.Disabled {
m.Start()
}
}
// Subscribe returns channels to listen to system and GPU stats.
func (m *Monitor) Subscribe() (chan SysStat, chan []GpuStat, func()) {
m.mutex.Lock()
defer m.mutex.Unlock()
sysChan := make(chan SysStat, 1)
gpuChan := make(chan []GpuStat, 1)
m.sysListeners[sysChan] = struct{}{}
m.gpuListeners[gpuChan] = struct{}{}
unsub := func() {
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.sysListeners, sysChan)
delete(m.gpuListeners, gpuChan)
}
return sysChan, gpuChan, unsub
}
func (m *Monitor) Start() {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.stopCancel != nil {
return
}
m.stopCtx, m.stopCancel = context.WithCancel(context.Background())
go func() {
tick := time.NewTicker(m.conf.Every)
defer tick.Stop()
for {
select {
case <-m.stopCtx.Done():
return
case <-tick.C:
s, err := ReadSysStats()
if err != nil {
if err != ErrNotImplemented {
m.log.Errorf("failed to read sys stats: %s", err.Error())
}
continue
}
m.mutex.Lock()
m.sysRing.Push(s)
for l := range m.sysListeners {
select {
case l <- s:
default:
}
}
m.mutex.Unlock()
}
}
}()
go func() {
gpuCh, err := getGpuStats(m.stopCtx, m.conf.Every, m.log)
if err != nil {
if errors.Is(err, ErrNotImplemented) || errors.Is(err, ErrNoGpuTool) {
m.log.Infof("GPU monitoring not available: %s", err.Error())
} else {
m.log.Errorf("failed to initialize GPU monitoring: %s", err.Error())
}
return
}
for {
select {
case <-m.stopCtx.Done():
return
case g, ok := <-gpuCh:
if !ok {
m.log.Errorf("failed reading from gpuCh - stopping read goroutine")
return
}
m.mutex.Lock()
m.gpuRing.Push(g)
for l := range m.gpuListeners {
select {
case l <- g:
default:
}
}
m.mutex.Unlock()
}
}
}()
}
// Current returns a copy of the current log of system and GPU stats.
func (m *Monitor) Current() ([]SysStat, []GpuStat) {
m.mutex.RLock()
defer m.mutex.RUnlock()
sysStats := m.sysRing.Slice()
snapshots := m.gpuRing.Slice()
var gpuStats []GpuStat
for _, snapshot := range snapshots {
gpuStats = append(gpuStats, snapshot...)
}
return sysStats, gpuStats
}
func ReadSysStats() (SysStat, error) {
return readSysStats()
}
func GetGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
return getGpuStats(ctx, every, logger)
}
+208
View File
@@ -0,0 +1,208 @@
package perf
import (
"bufio"
"context"
"fmt"
"os/exec"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/load"
"github.com/shirou/gopsutil/v4/mem"
)
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if ch, err := tryMactop(ctx, every, logger); err == nil {
logger.Info("using mactop for GPU monitoring")
return ch, nil
} else {
logger.Debugf("mactop: %s", err.Error())
}
if ch, err := tryIoreg(ctx, every, logger); err == nil {
logger.Info("using ioreg for GPU monitoring")
return ch, nil
} else {
logger.Debugf("ioreg: %s", err.Error())
}
return nil, ErrNoGpuTool
}
// tryIoreg polls `ioreg -r -c IOGPU -d 1 -f` for Apple Silicon GPU stats. It is
// a fallback for when mactop is not installed. ioreg exposes GPU utilization and
// used memory but not power, temperature, or fan speed.
func tryIoreg(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("ioreg"); err != nil {
return nil, ErrNoGpuTool
}
// Verify ioreg actually reports a GPU device before committing to it, so we
// can fall through to ErrNoGpuTool otherwise.
if stat := sampleIoreg(ctx); stat == nil {
return nil, fmt.Errorf("ioreg reported no GPU device")
}
if every < time.Second {
every = time.Second
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
stat := sampleIoreg(ctx)
if stat == nil {
continue
}
select {
case ch <- []GpuStat{*stat}:
default:
}
}
}
}()
return ch, nil
}
// sampleIoreg runs ioreg once and parses a single GpuStat, or returns nil.
func sampleIoreg(ctx context.Context) *GpuStat {
out, err := exec.CommandContext(ctx, "ioreg", "-r", "-c", "IOGPU", "-d", "1", "-f").Output()
if err != nil {
return nil
}
var memTotalMB int
if vmStat, err := mem.VirtualMemory(); err == nil {
memTotalMB = int(vmStat.Total / (1024 * 1024))
}
return ParseIoregOutput(out, memTotalMB)
}
// overlayIoregMem replaces a GpuStat's memory fields with the GPU-attributed
// unified memory reported by ioreg. mactop only exposes whole-system memory, so
// without this the mactop and ioreg backends would report different memory
// semantics. It is a no-op when ioreg is unavailable or reports no GPU memory,
// leaving the mactop-supplied values in place.
func overlayIoregMem(ctx context.Context, stat *GpuStat) {
ioStat := sampleIoreg(ctx)
if ioStat == nil {
return
}
stat.MemUsedMB = ioStat.MemUsedMB
stat.MemTotalMB = ioStat.MemTotalMB
stat.MemUtilPct = ioStat.MemUtilPct
}
// tryMactop streams Apple Silicon GPU stats from mactop's headless mode.
// See https://github.com/metaspartan/mactop. mactop emits one JSON object per
// sample to stdout, which we parse into GpuStat.
func tryMactop(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("mactop"); err != nil {
return nil, ErrNoGpuTool
}
// mactop samples power over the interval, so give it at least a second.
intervalMs := int(every.Milliseconds())
if intervalMs < 1000 {
intervalMs = 1000
}
cmd := exec.CommandContext(ctx, "mactop",
"--headless",
"--format", "json",
"--interval", fmt.Sprintf("%d", intervalMs),
)
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("mactop stdout pipe failed: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("mactop start failed: %w", err)
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
scanner := bufio.NewScanner(stdout)
// mactop's JSON objects can be large; allow generous line lengths.
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
stat := ParseMactopLine(line)
if stat != nil {
// mactop only reports whole-system memory; overlay ioreg's
// GPU-attributed unified memory so both backends are consistent.
overlayIoregMem(ctx, stat)
select {
case ch <- []GpuStat{*stat}:
default:
}
}
}
cmd.Wait()
}()
return ch, nil
}
func readSysStats() (SysStat, error) {
cpuPcts, err := cpu.Percent(0, true)
if err != nil {
return SysStat{}, err
}
vmStat, err := mem.VirtualMemory()
if err != nil {
return SysStat{}, err
}
const toMB = 1024 * 1024
var swapTotalMB, swapUsedMB int
if swapStat, err := mem.SwapMemory(); err == nil {
swapTotalMB = int(swapStat.Total / toMB)
swapUsedMB = int(swapStat.Used / toMB)
}
var loadAvg1, loadAvg5, loadAvg15 float64
if loadStat, err := load.Avg(); err == nil {
loadAvg1 = loadStat.Load1
loadAvg5 = loadStat.Load5
loadAvg15 = loadStat.Load15
}
return SysStat{
Timestamp: time.Now(),
CpuUtilPerCore: cpuPcts,
MemTotalMB: int(vmStat.Total / toMB),
MemUsedMB: int(vmStat.Used / toMB),
MemFreeMB: int(vmStat.Free / toMB),
SwapTotalMB: swapTotalMB,
SwapUsedMB: swapUsedMB,
LoadAvg1: loadAvg1,
LoadAvg5: loadAvg5,
LoadAvg15: loadAvg15,
}, nil
}
+313
View File
@@ -0,0 +1,313 @@
package perf
import (
"io"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestLogger() *logmon.Monitor {
return logmon.NewWriter(io.Discard)
}
func TestNew_DefaultConfig(t *testing.T) {
logger := newTestLogger()
m, err := New(config.PerformanceConfig{}, logger)
require.NoError(t, err)
require.NotNil(t, m)
assert.Equal(t, 100*time.Millisecond, m.conf.Every)
}
func TestNew_CustomConfig(t *testing.T) {
logger := newTestLogger()
cfg := config.PerformanceConfig{
Every: 500 * time.Millisecond,
}
m, err := New(cfg, logger)
require.NoError(t, err)
assert.Equal(t, 500*time.Millisecond, m.conf.Every)
}
func TestNew_NilLogger(t *testing.T) {
m, err := New(config.PerformanceConfig{}, nil)
assert.Error(t, err)
assert.Nil(t, m)
}
func TestNew_BelowMinimumConfig(t *testing.T) {
logger := newTestLogger()
cfg := config.PerformanceConfig{
Every: 1 * time.Millisecond,
}
m, err := New(cfg, logger)
require.NoError(t, err)
assert.Equal(t, 100*time.Millisecond, m.conf.Every)
}
func TestSubscribe_ReturnsChannels(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
sysCh, gpuCh, unsub := m.Subscribe()
defer unsub()
assert.NotNil(t, sysCh)
assert.NotNil(t, gpuCh)
assert.NotNil(t, unsub)
}
func TestSubscribe_UnsubscribeRemovesListeners(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
_, _, unsub := m.Subscribe()
m.mutex.RLock()
assert.Len(t, m.sysListeners, 1)
assert.Len(t, m.gpuListeners, 1)
m.mutex.RUnlock()
unsub()
m.mutex.RLock()
assert.Len(t, m.sysListeners, 0)
assert.Len(t, m.gpuListeners, 0)
m.mutex.RUnlock()
}
func TestSubscribe_MultipleSubscriptions(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
sysCh1, gpuCh1, unsub1 := m.Subscribe()
sysCh2, gpuCh2, unsub2 := m.Subscribe()
defer unsub1()
defer unsub2()
assert.NotEqual(t, sysCh1, sysCh2)
assert.NotEqual(t, gpuCh1, gpuCh2)
m.mutex.RLock()
assert.Len(t, m.sysListeners, 2)
assert.Len(t, m.gpuListeners, 2)
m.mutex.RUnlock()
}
func TestCurrent_EmptyByDefault(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
sysStats, gpuStats := m.Current()
assert.Empty(t, sysStats)
assert.Empty(t, gpuStats)
}
func TestCurrent_ReturnsCopies(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
now := time.Now()
m.sysRing.Push(SysStat{Timestamp: now, MemTotalMB: 1024})
m.gpuRing.Push([]GpuStat{{Timestamp: now, ID: 0, Name: "gpu0"}})
sysStats, gpuStats := m.Current()
assert.Len(t, sysStats, 1)
assert.Len(t, gpuStats, 1)
assert.Equal(t, 1024, sysStats[0].MemTotalMB)
assert.Equal(t, "gpu0", gpuStats[0].Name)
// modifying the returned slice should not affect the original
sysStats[0].MemTotalMB = 999
original, _ := m.Current()
assert.Equal(t, 1024, original[0].MemTotalMB)
}
func TestStart_CollectsSysStats(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
m, err := New(config.PerformanceConfig{Every: 100 * time.Millisecond}, newTestLogger())
require.NoError(t, err)
m.Start()
time.Sleep(350 * time.Millisecond)
m.Stop()
sysStats, _ := m.Current()
assert.NotEmpty(t, sysStats, "expected sys stats to be collected")
}
func TestStart_StopStopsGoroutines(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
m, err := New(config.PerformanceConfig{Every: 100 * time.Millisecond}, newTestLogger())
require.NoError(t, err)
m.Start()
if m.stopCancel == nil {
t.Error("stopCancel should not be nil after Start()")
}
m.Stop()
if m.stopCancel != nil {
t.Error("stopCancel should be nil after Stop()")
}
}
func TestStart_SubscriberReceivesStats(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
m, err := New(config.PerformanceConfig{Every: 100 * time.Millisecond}, newTestLogger())
require.NoError(t, err)
sysCh, _, unsub := m.Subscribe()
defer unsub()
m.Start()
defer m.Stop()
select {
case s := <-sysCh:
assert.False(t, s.Timestamp.IsZero())
assert.NotEmpty(t, s.CpuUtilPerCore)
case <-time.After(500 * time.Millisecond):
t.Fatal("timed out waiting for sys stats")
}
}
func TestReadSysStats(t *testing.T) {
s, err := ReadSysStats()
require.NoError(t, err)
assert.False(t, s.Timestamp.IsZero())
assert.NotEmpty(t, s.CpuUtilPerCore)
assert.Greater(t, s.MemTotalMB, 0)
}
func TestCurrent_ConcurrentAccess(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
m.sysRing.Push(SysStat{Timestamp: time.Now(), MemTotalMB: 1024})
m.gpuRing.Push([]GpuStat{{Timestamp: time.Now(), ID: 0}})
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
sys, gpu := m.Current()
assert.Len(t, sys, 1)
assert.Len(t, gpu, 1)
}()
}
wg.Wait()
}
func TestParseNvidiaSmiLine_ValidLine(t *testing.T) {
line := "0, NVIDIA GeForce RTX 3080, GPU-12345678-1234-1234-1234-123456789abc, 65, 80, 8192, 10240, 75, 250"
stat := ParseNvidiaSmiLine(line)
require.NotNil(t, stat)
assert.Equal(t, 0, stat.ID)
assert.Equal(t, "NVIDIA GeForce RTX 3080", stat.Name)
assert.Equal(t, "GPU-12345678-1234-1234-1234-123456789abc", stat.UUID)
assert.Equal(t, 65, stat.TempC)
assert.Equal(t, 80.0, stat.GpuUtilPct)
assert.Equal(t, 8192, stat.MemUsedMB)
assert.Equal(t, 10240, stat.MemTotalMB)
assert.Equal(t, 75.0, stat.FanSpeedPct)
assert.Equal(t, 250.0, stat.PowerDrawW)
assert.InDelta(t, 80.0, stat.MemUtilPct, 0.01)
}
func TestParseNvidiaSmiLine_ShortLine(t *testing.T) {
line := "0, NVIDIA GPU, GPU-123"
stat := ParseNvidiaSmiLine(line)
assert.Nil(t, stat)
}
func TestParseNvidiaSmiLine_MissingFields(t *testing.T) {
line := "0, NVIDIA GPU, GPU-123, 65, 80, 8192, 10240, 75"
stat := ParseNvidiaSmiLine(line)
assert.Nil(t, stat)
}
func TestParseNvidiaSmiLine_ZeroMemoryTotal(t *testing.T) {
line := "0, NVIDIA GPU, GPU-123, 65, 80, 0, 0, 75, 250"
stat := ParseNvidiaSmiLine(line)
require.NotNil(t, stat)
assert.Equal(t, 0.0, stat.MemUtilPct)
}
const ioregSample = `+-o AGXAcceleratorG13X <class AGXAcceleratorG13X, id 0x1000009a1, registered, matched, active, busy 0 (39191 ms), retain 108>
{
"model" = "Apple M1 Pro"
"gpu-core-count" = 16
"PerformanceStatistics" = {"In use system memory (driver)"=0,"Alloc system memory"=14511046656,"Tiler Utilization %"=34,"recoveryCount"=0,"Renderer Utilization %"=34,"Device Utilization %"=34,"In use system memory"=7688503296}
"IOClass" = "AGXAcceleratorG13X"
}`
func TestParseIoregOutput_ValidOutput(t *testing.T) {
const memTotalMB = 32768
stat := ParseIoregOutput([]byte(ioregSample), memTotalMB)
require.NotNil(t, stat)
assert.Equal(t, 0, stat.ID)
assert.Equal(t, "Apple M1 Pro (16-core GPU)", stat.Name)
assert.Equal(t, 34.0, stat.GpuUtilPct)
assert.Equal(t, 7688503296/(1024*1024), stat.MemUsedMB)
assert.Equal(t, memTotalMB, stat.MemTotalMB)
assert.InDelta(t, float64(stat.MemUsedMB)/memTotalMB*100, stat.MemUtilPct, 0.01)
// Not exposed by ioreg.
assert.Equal(t, 0, stat.TempC)
assert.Equal(t, 0.0, stat.PowerDrawW)
assert.Equal(t, 0.0, stat.FanSpeedPct)
}
func TestParseIoregOutput_NoGpuDevice(t *testing.T) {
stat := ParseIoregOutput([]byte("no gpu here"), 32768)
assert.Nil(t, stat)
}
func TestParseIoregOutput_ZeroMemTotal(t *testing.T) {
stat := ParseIoregOutput([]byte(ioregSample), 0)
require.NotNil(t, stat)
assert.Equal(t, 0.0, stat.MemUtilPct)
}
func TestParseIoregOutput_MissingModel(t *testing.T) {
const out = `"Device Utilization %"=50,"In use system memory"=1048576`
stat := ParseIoregOutput([]byte(out), 1024)
require.NotNil(t, stat)
assert.Equal(t, "Apple GPU", stat.Name)
assert.Equal(t, 50.0, stat.GpuUtilPct)
assert.Equal(t, 1, stat.MemUsedMB)
}
+584
View File
@@ -0,0 +1,584 @@
//go:build unix && !darwin
package perf
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/load"
"github.com/shirou/gopsutil/v4/mem"
psnet "github.com/shirou/gopsutil/v4/net"
)
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if ch, err := tryLACT(ctx, every, logger); err == nil {
logger.Info("using LACT for GPU monitoring")
return ch, nil
} else {
logger.Debugf("LACT: %s", err.Error())
}
if ch, err := tryNvidiaSmi(ctx, every, logger); err == nil {
logger.Info("using nvidia-smi for GPU monitoring")
return ch, nil
} else {
logger.Debugf("nvidia-smi: %s", err.Error())
}
if ch, err := tryRocmSmi(ctx, every, logger); err == nil {
logger.Info("using rocm-smi for GPU monitoring")
return ch, nil
} else {
logger.Debugf("rocm-smi: %s", err.Error())
}
if ch, err := trySysfs(ctx, every, logger); err == nil {
logger.Info("using sysfs for GPU monitoring")
return ch, nil
} else {
logger.Debugf("sysfs: %s", err.Error())
}
return nil, ErrNoGpuTool
}
func tryLACT(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
socketPath := lactSocketPath()
if socketPath == "" {
return nil, ErrNoGpuTool
}
conn, err := net.DialTimeout("unix", socketPath, 2*time.Second)
if err != nil {
return nil, fmt.Errorf("cannot connect to LACT socket: %w", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
devices, err := lactListDevices(conn)
if err != nil {
return nil, fmt.Errorf("LACT ListDevices failed: %w", err)
}
if len(devices) == 0 {
return nil, fmt.Errorf("LACT returned no devices")
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
socketPath := lactSocketPath()
if socketPath == "" {
continue
}
conn, err := net.DialTimeout("unix", socketPath, 2*time.Second)
if err != nil {
continue
}
conn.SetDeadline(time.Now().Add(5 * time.Second))
devices, err := lactListDevices(conn)
if err != nil {
conn.Close()
continue
}
stats := make([]GpuStat, 0, len(devices))
for i, d := range devices {
stat, err := lactGetDeviceStats(conn, d.ID, d.Name, i)
if err != nil {
continue
}
if stat.MemTotalMB == 0 {
continue
}
stats = append(stats, stat)
}
conn.Close()
if len(stats) > 0 {
select {
case ch <- stats:
default:
}
}
}
}
}()
return ch, nil
}
func tryNvidiaSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("nvidia-smi"); err != nil {
return nil, ErrNoGpuTool
}
sec := int(every.Seconds())
if sec < 1 {
sec = 1
}
cmd := exec.CommandContext(ctx, "nvidia-smi",
"--query-gpu=index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw",
"--format=csv,noheader,nounits",
"--loop", fmt.Sprintf("%d", sec),
)
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("nvidia-smi stdout pipe failed: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("nvidia-smi start failed: %w", err)
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
stat := ParseNvidiaSmiLine(line)
if stat != nil {
select {
case ch <- []GpuStat{*stat}:
default:
}
}
}
cmd.Wait()
}()
return ch, nil
}
func tryRocmSmi(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("rocm-smi"); err != nil {
return nil, ErrNoGpuTool
}
if every < time.Second {
every = time.Second
}
const pollTimeout = 5 * time.Second
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
pollCtx, cancel := context.WithTimeout(ctx, pollTimeout)
cmd := exec.CommandContext(pollCtx, "rocm-smi", "-i", "-P", "-t", "-f", "-u", "--showmemuse", "--showmeminfo", "vram", "--showproductname", "--csv")
out, err := cmd.Output()
timedOut := pollCtx.Err() == context.DeadlineExceeded
cancel()
if err != nil {
if timedOut {
logger.Debug("rocm-smi timed out")
}
continue
}
stats := make([]GpuStat, 0)
scanner := bufio.NewScanner(strings.NewReader(string(out)))
var header string
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
if strings.HasPrefix(line, "device,") {
header = line
continue
}
stat := parseRocmSmiLine(header, line)
if stat != nil {
stats = append(stats, *stat)
}
}
if len(stats) > 0 {
select {
case ch <- stats:
default:
}
}
}
}
}()
return ch, nil
}
func parseRocmSmiLine(header string, line string) *GpuStat {
if header == "" || line == "" {
return nil
}
labels := strings.Split(header, ",")
fields := strings.Split(line, ",")
if len(labels) != len(fields) {
return nil
}
result := &GpuStat{
Timestamp: time.Now(),
ID: -1,
}
var device string
var deviceName string
var cardSeries string
var gfxVersion string
const toMB = 1024 * 1024
for i, col := range labels {
val := strings.TrimSpace(fields[i])
switch col {
case "device":
device = val
id, err := strconv.Atoi(strings.TrimPrefix(val, "card"))
if err != nil {
return nil
}
result.ID = id
case "Device Name":
deviceName = val
case "GUID":
result.UUID = val
case "Temperature (Sensor edge) (C)":
tempC, _ := strconv.ParseFloat(val, 64)
result.TempC = int(tempC)
case "Temperature (Sensor memory) (C)":
vramTempC, _ := strconv.ParseFloat(val, 64)
result.VramTempC = int(vramTempC)
case "Fan speed (%)":
fanSpeed, _ := strconv.ParseFloat(val, 64)
result.FanSpeedPct = fanSpeed
case "Current Socket Graphics Package Power (W)":
fallthrough
case "Average Graphics Package Power (W)":
powerDraw, _ := strconv.ParseFloat(val, 64)
result.PowerDrawW = powerDraw
case "GPU use (%)":
gpuUtil, _ := strconv.ParseFloat(val, 64)
result.GpuUtilPct = gpuUtil
case "GPU Memory Allocated (VRAM%)":
memUtil, _ := strconv.ParseFloat(val, 64)
result.MemUtilPct = memUtil
case "VRAM Total Memory (B)":
memTotal, _ := strconv.ParseUint(val, 10, 64)
result.MemTotalMB = int(memTotal / toMB)
case "VRAM Total Used Memory (B)":
memUsed, _ := strconv.ParseUint(val, 10, 64)
result.MemUsedMB = int(memUsed / toMB)
case "Card Series":
cardSeries = val
case "GFX Version":
gfxVersion = val
}
}
if result.ID == -1 {
return nil
}
name := device
if cardSeries != "" && cardSeries != "N/A" {
name = cardSeries + " " + device + " (" + gfxVersion + ")"
} else if deviceName != "" && deviceName != "N/A" {
name = deviceName + " " + device + " (" + gfxVersion + ")"
}
result.Name = name
return result
}
func trySysfs(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
return nil, ErrNotImplemented
}
func lactSocketPath() string {
if p := os.Getenv("LACT_DAEMON_SOCKET_PATH"); p != "" {
if _, err := os.Stat(p); err == nil {
return p
}
}
rootPath := "/run/lactd.sock"
if _, err := os.Stat(rootPath); err == nil {
return rootPath
}
u, err := user.Current()
if err != nil {
return ""
}
userPath := filepath.Join("/run/user", u.Uid, "lactd.sock")
if _, err := os.Stat(userPath); err == nil {
return userPath
}
return ""
}
type lactRequest struct {
Command string `json:"command"`
Args interface{} `json:"args,omitempty"`
}
type lactResponse struct {
Status string `json:"status"`
Data json.RawMessage `json:"data"`
}
type lactDeviceEntry struct {
ID string `json:"id"`
Name string `json:"name"`
}
type lactDeviceStats struct {
Fan struct {
PwmCurrent *uint8 `json:"pwm_current"`
} `json:"fan"`
Vram struct {
Total *uint64 `json:"total"`
Used *uint64 `json:"used"`
} `json:"vram"`
Power struct {
Average *float64 `json:"average"`
Current *float64 `json:"current"`
} `json:"power"`
Temps map[string]lactTempEntry `json:"temps"`
BusyPercent *uint8 `json:"busy_percent"`
}
type lactTempEntry struct {
Current *float64 `json:"current"`
}
func lactSendRequest(conn net.Conn, req lactRequest) (json.RawMessage, error) {
data, err := json.Marshal(req)
if err != nil {
return nil, err
}
data = append(data, '\n')
if _, err := conn.Write(data); err != nil {
return nil, err
}
reader := bufio.NewReader(conn)
line, err := reader.ReadBytes('\n')
if err != nil {
return nil, err
}
var resp lactResponse
if err := json.Unmarshal(line, &resp); err != nil {
return nil, err
}
if resp.Status != "ok" {
return nil, fmt.Errorf("LACT error: %s", string(resp.Data))
}
return resp.Data, nil
}
func lactListDevices(conn net.Conn) ([]lactDeviceEntry, error) {
data, err := lactSendRequest(conn, lactRequest{Command: "list_devices"})
if err != nil {
return nil, err
}
var devices []lactDeviceEntry
if err := json.Unmarshal(data, &devices); err != nil {
return nil, err
}
return devices, nil
}
func lactGetDeviceStats(conn net.Conn, id string, name string, index int) (GpuStat, error) {
data, err := lactSendRequest(conn, lactRequest{
Command: "device_stats",
Args: struct {
ID string `json:"id"`
}{ID: id},
})
if err != nil {
return GpuStat{}, err
}
var stats lactDeviceStats
if err := json.Unmarshal(data, &stats); err != nil {
return GpuStat{}, err
}
var memUsedMB, memTotalMB int
if stats.Vram.Used != nil {
memUsedMB = int(*stats.Vram.Used / 1024 / 1024)
}
if stats.Vram.Total != nil {
memTotalMB = int(*stats.Vram.Total / 1024 / 1024)
}
var memUtil float64
if memTotalMB > 0 {
memUtil = float64(memUsedMB) / float64(memTotalMB) * 100
}
var gpuUtil float64
if stats.BusyPercent != nil {
gpuUtil = float64(*stats.BusyPercent)
}
var fanSpeed float64
if stats.Fan.PwmCurrent != nil {
fanSpeed = float64(*stats.Fan.PwmCurrent) / 255.0 * 100.0
}
var powerDraw float64
if stats.Power.Average != nil && *stats.Power.Average > 0 {
powerDraw = *stats.Power.Average
} else if stats.Power.Current != nil {
powerDraw = *stats.Power.Current
}
var tempC int
if t, ok := stats.Temps["edge"]; ok && t.Current != nil {
tempC = int(*t.Current)
} else if t, ok := stats.Temps["junction"]; ok && t.Current != nil {
tempC = int(*t.Current)
} else {
for _, t := range stats.Temps {
if t.Current != nil {
tempC = int(*t.Current)
break
}
}
}
var vramTempC int
// nvidia uses "VRAM", amd "mem"
for _, key := range []string{"mem", "VRAM"} {
if t, ok := stats.Temps[key]; ok && t.Current != nil && *t.Current > 0 {
vramTempC = int(*t.Current)
break
}
}
return GpuStat{
Timestamp: time.Now(),
ID: index,
Name: name,
UUID: id,
TempC: tempC,
VramTempC: vramTempC,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtil,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
FanSpeedPct: fanSpeed,
PowerDrawW: powerDraw,
}, nil
}
func readSysfs() ([]GpuStat, error) {
return nil, ErrNotImplemented
}
func readSysStats() (SysStat, error) {
cpuPcts, err := cpu.Percent(0, true)
if err != nil {
return SysStat{}, err
}
vmStat, err := mem.VirtualMemory()
if err != nil {
return SysStat{}, err
}
const toMB = 1024 * 1024
var swapTotalMB, swapUsedMB int
if swapStat, err := mem.SwapMemory(); err == nil {
swapTotalMB = int(swapStat.Total / toMB)
swapUsedMB = int(swapStat.Used / toMB)
}
var loadAvg1, loadAvg5, loadAvg15 float64
if loadStat, err := load.Avg(); err == nil {
loadAvg1 = loadStat.Load1
loadAvg5 = loadStat.Load5
loadAvg15 = loadStat.Load15
}
netIO := make([]NetIOStat, 0)
if ioCounters, err := psnet.IOCounters(true); err == nil {
for _, ioc := range ioCounters {
if ioc.Name == "lo" {
continue
}
netIO = append(netIO, NetIOStat{
Name: ioc.Name,
BytesRecv: ioc.BytesRecv,
BytesSent: ioc.BytesSent,
})
}
}
return SysStat{
Timestamp: time.Now(),
CpuUtilPerCore: cpuPcts,
MemTotalMB: int(vmStat.Total / toMB),
MemUsedMB: int(vmStat.Used / toMB),
MemFreeMB: int(vmStat.Free / toMB),
SwapTotalMB: swapTotalMB,
SwapUsedMB: swapUsedMB,
LoadAvg1: loadAvg1,
LoadAvg5: loadAvg5,
LoadAvg15: loadAvg15,
NetIO: netIO,
}, nil
}
+121
View File
@@ -0,0 +1,121 @@
package perf
import (
"bufio"
"context"
"fmt"
"os/exec"
"strings"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/mem"
"github.com/shirou/gopsutil/v4/net"
)
func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if ch, err := tryNvidiaSmiWindows(ctx, every, logger); err == nil {
logger.Info("using nvidia-smi for GPU monitoring")
return ch, nil
} else {
logger.Debugf("nvidia-smi: %s", err.Error())
}
if ch, err := tryD3DKMT(ctx, every, logger); err == nil {
logger.Info("using D3DKMT for GPU monitoring")
return ch, nil
} else {
logger.Debugf("D3DKMT: %s", err.Error())
}
return nil, ErrNoGpuTool
}
// tryNvidiaSmiWindows starts nvidia-smi in loop mode on Windows and returns
// a channel receiving GPU stat snapshots. Returns ErrNoGpuTool if nvidia-smi
// is not available.
func tryNvidiaSmiWindows(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if _, err := exec.LookPath("nvidia-smi"); err != nil {
return nil, ErrNoGpuTool
}
sec := int(every.Seconds())
if sec < 1 {
sec = 1
}
cmd := exec.CommandContext(ctx, "nvidia-smi",
"--query-gpu=index,name,uuid,temperature.gpu,utilization.gpu,memory.used,memory.total,fan.speed,power.draw",
"--format=csv,noheader,nounits",
"--loop", fmt.Sprintf("%d", sec),
)
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("nvidia-smi stdout pipe failed: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("nvidia-smi start failed: %w", err)
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
stat := ParseNvidiaSmiLine(line)
if stat != nil {
select {
case ch <- []GpuStat{*stat}:
default:
}
}
}
cmd.Wait()
}()
return ch, nil
}
func readSysStats() (SysStat, error) {
cpuPcts, err := cpu.Percent(0, true)
if err != nil {
return SysStat{}, err
}
vmStat, err := mem.VirtualMemory()
if err != nil {
return SysStat{}, err
}
const toMB = 1024 * 1024
netIO := make([]NetIOStat, 0)
if ioCounters, err := net.IOCounters(true); err == nil {
for _, ioc := range ioCounters {
netIO = append(netIO, NetIOStat{
Name: ioc.Name,
BytesRecv: ioc.BytesRecv,
BytesSent: ioc.BytesSent,
})
}
}
return SysStat{
Timestamp: time.Now(),
CpuUtilPerCore: cpuPcts,
MemTotalMB: int(vmStat.Total / toMB),
MemUsedMB: int(vmStat.Used / toMB),
MemFreeMB: int(vmStat.Free / toMB),
NetIO: netIO,
}, nil
}
+159
View File
@@ -0,0 +1,159 @@
//go:build windows
package perf
import (
"fmt"
"strconv"
"strings"
"unsafe"
"golang.org/x/sys/windows"
)
var (
pdhDLL = windows.NewLazySystemDLL("pdh.dll")
procPdhOpenQuery = pdhDLL.NewProc("PdhOpenQueryW")
procPdhAddEnglishCounter = pdhDLL.NewProc("PdhAddEnglishCounterW")
procPdhCollectQueryData = pdhDLL.NewProc("PdhCollectQueryData")
procPdhGetFormattedCounterArray = pdhDLL.NewProc("PdhGetFormattedCounterArrayW")
procPdhCloseQuery = pdhDLL.NewProc("PdhCloseQuery")
)
const (
pdhFmtDouble = 0x00000200
pdhMoreData = 0x800007D2
pdhNoData = 0x800007D5
)
type pdhCounterValue struct {
CStatus uint32
DblVal float64
}
type pdhCounterValueItem struct {
SzName *uint16
FmtValue pdhCounterValue
}
func init() {
var item pdhCounterValueItem
if unsafe.Sizeof(item) != 24 {
panic(fmt.Sprintf("pdhCounterValueItem size %d != expected 24 on x64", unsafe.Sizeof(item)))
}
}
type pdhGpuUtil struct {
query uintptr
counter uintptr
}
// initPdhGpuUtil creates a PDH query for the GPU Engine utilization counter.
// Returns nil with an error if PDH or the counter is unavailable.
func initPdhGpuUtil() (*pdhGpuUtil, error) {
var query uintptr
if ret, _, _ := procPdhOpenQuery.Call(0, 0, uintptr(unsafe.Pointer(&query))); ret != 0 {
return nil, fmt.Errorf("PdhOpenQuery: 0x%x", ret)
}
path, _ := windows.UTF16PtrFromString(`\GPU Engine(*)\Utilization Percentage`)
var counter uintptr
if ret, _, _ := procPdhAddEnglishCounter.Call(
query, uintptr(unsafe.Pointer(path)), 0, uintptr(unsafe.Pointer(&counter)),
); ret != 0 {
procPdhCloseQuery.Call(query)
return nil, fmt.Errorf("PdhAddEnglishCounter(GPU Engine): 0x%x", ret)
}
procPdhCollectQueryData.Call(query)
return &pdhGpuUtil{query: query, counter: counter}, nil
}
// close releases the PDH query handle.
func (p *pdhGpuUtil) close() {
if p.query != 0 {
procPdhCloseQuery.Call(p.query)
p.query = 0
}
}
// collect reads the PDH counter and returns a map of adapter LUID to
// aggregated GPU utilization percentage, summed across all engine instances
// per adapter and clamped to 100%.
func (p *pdhGpuUtil) collect() map[LUID]float64 {
ret, _, _ := procPdhCollectQueryData.Call(p.query)
if ret != 0 && ret != pdhNoData {
return nil
}
var bufSize uint32
var itemCount uint32
ret, _, _ = procPdhGetFormattedCounterArray.Call(
p.counter, pdhFmtDouble,
uintptr(unsafe.Pointer(&bufSize)),
uintptr(unsafe.Pointer(&itemCount)),
0,
)
if ret != pdhMoreData || itemCount == 0 {
return nil
}
buf := make([]byte, bufSize)
ret, _, _ = procPdhGetFormattedCounterArray.Call(
p.counter, pdhFmtDouble,
uintptr(unsafe.Pointer(&bufSize)),
uintptr(unsafe.Pointer(&itemCount)),
uintptr(unsafe.Pointer(&buf[0])),
)
if ret != 0 {
return nil
}
itemSize := uint32(unsafe.Sizeof(pdhCounterValueItem{}))
result := make(map[LUID]float64)
for i := uint32(0); i < itemCount; i++ {
item := (*pdhCounterValueItem)(unsafe.Pointer(&buf[i*itemSize]))
if item.FmtValue.CStatus != 0 {
continue
}
luid, ok := parsePdhLuid(windows.UTF16PtrToString(item.SzName))
if !ok {
continue
}
result[luid] += item.FmtValue.DblVal
}
for luid := range result {
if result[luid] > 100.0 {
result[luid] = 100.0
}
}
return result
}
// parsePdhLuid extracts the adapter LUID (high and low parts) from a PDH
// GPU Engine instance name (e.g. "pid_1234_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute").
func parsePdhLuid(name string) (LUID, bool) {
idx := strings.Index(name, "luid_0x")
if idx < 0 {
return LUID{}, false
}
rest := name[idx+7:]
parts := strings.SplitN(rest, "_", 4)
if len(parts) < 3 {
return LUID{}, false
}
hp, err := strconv.ParseUint(parts[0], 16, 32)
if err != nil {
return LUID{}, false
}
lpStr := strings.TrimPrefix(parts[1], "0x")
lp, err := strconv.ParseUint(lpStr, 16, 32)
if err != nil {
return LUID{}, false
}
return LUID{LowPart: uint32(lp), HighPart: int32(hp)}, true
}
+53
View File
@@ -0,0 +1,53 @@
//go:build windows
package perf
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParsePdhLuid_Valid(t *testing.T) {
name := `pid_25312_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x000148BF), got.LowPart)
assert.Equal(t, int32(0x00000000), got.HighPart)
}
func TestParsePdhLuid_ValidNvidia(t *testing.T) {
name := `pid_1388_luid_0x00000000_0x00011372_phys_0_eng_8_engtype_Compute_1`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x00011372), got.LowPart)
assert.Equal(t, int32(0x00000000), got.HighPart)
}
func TestParsePdhLuid_NonZeroHighPart(t *testing.T) {
name := `pid_1234_luid_0x00000001_0x0000C85A_phys_0_eng_5_engtype_Copy`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x0000C85A), got.LowPart)
assert.Equal(t, int32(0x00000001), got.HighPart)
}
func TestParsePdhLuid_InvalidNoLuid(t *testing.T) {
_, ok := parsePdhLuid("invalid_string_without_luid")
assert.False(t, ok)
}
func TestParsePdhLuid_InvalidEmpty(t *testing.T) {
_, ok := parsePdhLuid("")
assert.False(t, ok)
}
func TestParsePdhLuid_InvalidHex(t *testing.T) {
_, ok := parsePdhLuid("pid_1234_luid_0xZZZZ_0xGGGG_phys_0")
assert.False(t, ok)
}
func TestParsePdhLuid_ShortAfterLuid(t *testing.T) {
_, ok := parsePdhLuid("pid_1234_luid_0x00000000")
assert.False(t, ok)
}
+129
View File
@@ -0,0 +1,129 @@
package perf
import (
"fmt"
"net/http"
"sort"
"strings"
)
const mbToBytes = int64(1024 * 1024)
// MetricsHandler returns an http.HandlerFunc serving Prometheus text format metrics
// with the most recent system and GPU stats.
func (m *Monitor) MetricsHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sysStats, gpuStats := m.Current()
w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
if len(sysStats) > 0 {
writeSysMetrics(w, sysStats[len(sysStats)-1])
}
if len(gpuStats) > 0 {
writeGpuMetrics(w, latestPerGPU(gpuStats))
}
}
}
func writeSysMetrics(w http.ResponseWriter, s SysStat) {
fmt.Fprintf(w, "# HELP llamaswap_cpu_util_percent CPU utilization per core (0-100)\n")
fmt.Fprintf(w, "# TYPE llamaswap_cpu_util_percent gauge\n")
for i, pct := range s.CpuUtilPerCore {
fmt.Fprintf(w, "llamaswap_cpu_util_percent{core=\"%d\"} %g\n", i, pct)
}
fmt.Fprintf(w, "# HELP llamaswap_memory_total_bytes Total memory in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_memory_total_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_memory_total_bytes %d\n", int64(s.MemTotalMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_memory_used_bytes Used memory in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_memory_used_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_memory_used_bytes %d\n", int64(s.MemUsedMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_memory_free_bytes Free memory in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_memory_free_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_memory_free_bytes %d\n", int64(s.MemFreeMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_swap_total_bytes Total swap in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_swap_total_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_swap_total_bytes %d\n", int64(s.SwapTotalMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_swap_used_bytes Used swap in bytes\n")
fmt.Fprintf(w, "# TYPE llamaswap_swap_used_bytes gauge\n")
fmt.Fprintf(w, "llamaswap_swap_used_bytes %d\n", int64(s.SwapUsedMB)*mbToBytes)
fmt.Fprintf(w, "# HELP llamaswap_load_average Load average\n")
fmt.Fprintf(w, "# TYPE llamaswap_load_average gauge\n")
fmt.Fprintf(w, "llamaswap_load_average{interval=\"1m\"} %g\n", s.LoadAvg1)
fmt.Fprintf(w, "llamaswap_load_average{interval=\"5m\"} %g\n", s.LoadAvg5)
fmt.Fprintf(w, "llamaswap_load_average{interval=\"15m\"} %g\n", s.LoadAvg15)
if len(s.NetIO) > 0 {
fmt.Fprintf(w, "# HELP llamaswap_network_bytes_total Total network bytes transferred\n")
fmt.Fprintf(w, "# TYPE llamaswap_network_bytes_total counter\n")
for _, io := range s.NetIO {
iface := sanitizeLabel(io.Name)
fmt.Fprintf(w, "llamaswap_network_bytes_total{interface=\"%s\",direction=\"recv\"} %d\n", iface, io.BytesRecv)
fmt.Fprintf(w, "llamaswap_network_bytes_total{interface=\"%s\",direction=\"sent\"} %d\n", iface, io.BytesSent)
}
}
}
func writeGpuMetrics(w http.ResponseWriter, gpus []GpuStat) {
if len(gpus) == 0 {
return
}
type gpuMetric struct {
help string
name string
value func(GpuStat) float64
}
metrics := []gpuMetric{
{"GPU temperature in Celsius", "llamaswap_gpu_temperature_celsius", func(g GpuStat) float64 { return float64(g.TempC) }},
{"GPU VRAM temperature in Celsius", "llamaswap_gpu_vram_temperature_celsius", func(g GpuStat) float64 { return float64(g.VramTempC) }},
{"GPU utilization percent (0-100)", "llamaswap_gpu_util_percent", func(g GpuStat) float64 { return g.GpuUtilPct }},
{"GPU memory utilization percent (0-100)", "llamaswap_gpu_memory_util_percent", func(g GpuStat) float64 { return g.MemUtilPct }},
{"GPU memory used in bytes", "llamaswap_gpu_memory_used_bytes", func(g GpuStat) float64 { return float64(g.MemUsedMB) * float64(mbToBytes) }},
{"GPU memory total in bytes", "llamaswap_gpu_memory_total_bytes", func(g GpuStat) float64 { return float64(g.MemTotalMB) * float64(mbToBytes) }},
{"GPU fan speed percent (0-100)", "llamaswap_gpu_fan_speed_percent", func(g GpuStat) float64 { return g.FanSpeedPct }},
{"GPU power draw in watts", "llamaswap_gpu_power_draw_watts", func(g GpuStat) float64 { return g.PowerDrawW }},
}
for _, m := range metrics {
fmt.Fprintf(w, "# HELP %s %s\n", m.name, m.help)
fmt.Fprintf(w, "# TYPE %s gauge\n", m.name)
for _, g := range gpus {
if g.UUID != "" {
fmt.Fprintf(w, "%s{id=\"%d\",name=\"%s\",uuid=\"%s\"} %g\n",
m.name, g.ID, sanitizeLabel(g.Name), sanitizeLabel(g.UUID), m.value(g))
} else {
fmt.Fprintf(w, "%s{id=\"%d\",name=\"%s\"} %g\n",
m.name, g.ID, sanitizeLabel(g.Name), m.value(g))
}
}
}
}
// latestPerGPU returns the most recent GpuStat for each GPU ID, sorted by ID.
func latestPerGPU(stats []GpuStat) []GpuStat {
latest := make(map[int]GpuStat)
for _, g := range stats {
if prev, ok := latest[g.ID]; !ok || g.Timestamp.After(prev.Timestamp) {
latest[g.ID] = g
}
}
result := make([]GpuStat, 0, len(latest))
for _, g := range latest {
result = append(result, g)
}
sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
return result
}
// sanitizeLabel escapes characters that are invalid in Prometheus label values.
func sanitizeLabel(s string) string {
return strings.NewReplacer(`"`, `\"`, `\`, `\\`, "\n", `\n`).Replace(s)
}
+248
View File
@@ -0,0 +1,248 @@
package perf
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSanitizeLabel(t *testing.T) {
tests := []struct {
input string
want string
}{
{"normal", "normal"},
{"", ""},
{`with"quote`, `with\"quote`},
{`with\backslash`, `with\\backslash`},
{"with\nnewline", `with\nnewline`},
{`"both\n"`, `\"both\\n\"`},
}
for _, tc := range tests {
assert.Equal(t, tc.want, sanitizeLabel(tc.input), "input: %q", tc.input)
}
}
func TestLatestPerGPU_Empty(t *testing.T) {
result := latestPerGPU(nil)
assert.Empty(t, result)
}
func TestLatestPerGPU_Single(t *testing.T) {
now := time.Now()
stats := []GpuStat{{ID: 0, Name: "gpu0", Timestamp: now}}
result := latestPerGPU(stats)
require.Len(t, result, 1)
assert.Equal(t, "gpu0", result[0].Name)
}
func TestLatestPerGPU_PicksLatest(t *testing.T) {
earlier := time.Now().Add(-time.Second)
later := time.Now()
stats := []GpuStat{
{ID: 0, Name: "old", TempC: 50, Timestamp: earlier},
{ID: 0, Name: "new", TempC: 70, Timestamp: later},
}
result := latestPerGPU(stats)
require.Len(t, result, 1)
assert.Equal(t, "new", result[0].Name)
assert.Equal(t, 70, result[0].TempC)
}
func TestLatestPerGPU_MultipleGPUsSortedByID(t *testing.T) {
now := time.Now()
stats := []GpuStat{
{ID: 2, Name: "gpu2", Timestamp: now},
{ID: 0, Name: "gpu0", Timestamp: now},
{ID: 1, Name: "gpu1", Timestamp: now},
}
result := latestPerGPU(stats)
require.Len(t, result, 3)
assert.Equal(t, 0, result[0].ID)
assert.Equal(t, 1, result[1].ID)
assert.Equal(t, 2, result[2].ID)
}
func TestWriteSysMetrics(t *testing.T) {
rec := httptest.NewRecorder()
s := SysStat{
CpuUtilPerCore: []float64{10.5, 20.0},
MemTotalMB: 8192,
MemUsedMB: 4096,
MemFreeMB: 4096,
SwapTotalMB: 2048,
SwapUsedMB: 512,
LoadAvg1: 1.5,
LoadAvg5: 1.2,
LoadAvg15: 0.9,
NetIO: []NetIOStat{
{Name: "eth0", BytesRecv: 1000, BytesSent: 2000},
},
}
writeSysMetrics(rec, s)
body := rec.Body.String()
assert.Contains(t, body, `llamaswap_cpu_util_percent{core="0"} 10.5`)
assert.Contains(t, body, `llamaswap_cpu_util_percent{core="1"} 20`)
assert.Contains(t, body, "llamaswap_memory_total_bytes 8589934592")
assert.Contains(t, body, "llamaswap_memory_used_bytes 4294967296")
assert.Contains(t, body, "llamaswap_memory_free_bytes 4294967296")
assert.Contains(t, body, "llamaswap_swap_total_bytes 2147483648")
assert.Contains(t, body, "llamaswap_swap_used_bytes 536870912")
assert.Contains(t, body, `llamaswap_load_average{interval="1m"} 1.5`)
assert.Contains(t, body, `llamaswap_load_average{interval="5m"} 1.2`)
assert.Contains(t, body, `llamaswap_load_average{interval="15m"} 0.9`)
assert.Contains(t, body, `llamaswap_network_bytes_total{interface="eth0",direction="recv"} 1000`)
assert.Contains(t, body, `llamaswap_network_bytes_total{interface="eth0",direction="sent"} 2000`)
}
func TestWriteSysMetrics_NoNetIO(t *testing.T) {
rec := httptest.NewRecorder()
writeSysMetrics(rec, SysStat{CpuUtilPerCore: []float64{5.0}})
body := rec.Body.String()
assert.NotContains(t, body, "llamaswap_network_bytes_total")
}
func TestWriteGpuMetrics_Empty(t *testing.T) {
rec := httptest.NewRecorder()
writeGpuMetrics(rec, nil)
assert.Empty(t, rec.Body.String())
}
func TestWriteGpuMetrics(t *testing.T) {
rec := httptest.NewRecorder()
gpus := []GpuStat{
{
ID: 0,
Name: "NVIDIA RTX 4090",
UUID: "GPU-1234",
TempC: 75,
GpuUtilPct: 85.5,
MemUtilPct: 60.0,
MemUsedMB: 8192,
MemTotalMB: 24576,
FanSpeedPct: 55.0,
PowerDrawW: 300.5,
},
}
writeGpuMetrics(rec, gpus)
body := rec.Body.String()
assert.Contains(t, body, `llamaswap_gpu_temperature_celsius{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 75`)
assert.Contains(t, body, `llamaswap_gpu_vram_temperature_celsius{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 0`)
assert.Contains(t, body, `llamaswap_gpu_util_percent{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 85.5`)
assert.Contains(t, body, `llamaswap_gpu_memory_util_percent{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 60`)
assert.Contains(t, body, `llamaswap_gpu_memory_used_bytes{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"}`)
assert.Contains(t, body, `llamaswap_gpu_memory_total_bytes{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"}`)
assert.Contains(t, body, `llamaswap_gpu_fan_speed_percent{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 55`)
assert.Contains(t, body, `llamaswap_gpu_power_draw_watts{id="0",name="NVIDIA RTX 4090",uuid="GPU-1234"} 300.5`)
}
func TestWriteGpuMetrics_VramTemp(t *testing.T) {
rec := httptest.NewRecorder()
gpus := []GpuStat{
{ID: 0, Name: "AMD RX 7900", UUID: "GPU-5678", TempC: 70, VramTempC: 85},
}
writeGpuMetrics(rec, gpus)
body := rec.Body.String()
assert.Contains(t, body, `llamaswap_gpu_temperature_celsius{id="0",name="AMD RX 7900",uuid="GPU-5678"} 70`)
assert.Contains(t, body, `llamaswap_gpu_vram_temperature_celsius{id="0",name="AMD RX 7900",uuid="GPU-5678"} 85`)
}
func TestWriteGpuMetrics_EmptyUUID(t *testing.T) {
rec := httptest.NewRecorder()
gpus := []GpuStat{{ID: 3, Name: "AMD RX 7900", UUID: ""}}
writeGpuMetrics(rec, gpus)
body := rec.Body.String()
assert.NotContains(t, body, "uuid=")
assert.Contains(t, body, `name="AMD RX 7900"`)
}
func TestWriteGpuMetrics_LabelSanitization(t *testing.T) {
rec := httptest.NewRecorder()
gpus := []GpuStat{
{ID: 0, Name: `GPU "special"`, UUID: "uuid\nline"},
}
writeGpuMetrics(rec, gpus)
body := rec.Body.String()
assert.Contains(t, body, `name="GPU \"special\""`)
assert.Contains(t, body, `uuid="uuid\nline"`)
}
func TestMetricsHandler_ContentType(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
assert.Equal(t, "text/plain; version=0.0.4; charset=utf-8", rec.Header().Get("Content-Type"))
}
func TestMetricsHandler_EmptyStats(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Empty(t, strings.TrimSpace(rec.Body.String()))
}
func TestMetricsHandler_WithSysStats(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
m.sysRing.Push(SysStat{Timestamp: time.Now(), CpuUtilPerCore: []float64{25.0}, MemTotalMB: 4096, MemUsedMB: 2048, MemFreeMB: 2048})
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
body := rec.Body.String()
assert.Contains(t, body, "llamaswap_cpu_util_percent")
assert.Contains(t, body, "llamaswap_memory_total_bytes")
}
func TestMetricsHandler_UsesLatestSysStat(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
now := time.Now()
m.sysRing.Push(SysStat{Timestamp: now.Add(-time.Second), MemTotalMB: 1000})
m.sysRing.Push(SysStat{Timestamp: now, MemTotalMB: 8192})
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
body := rec.Body.String()
// 8192 MB = 8589934592 bytes
assert.Contains(t, body, "llamaswap_memory_total_bytes 8589934592")
}
func TestMetricsHandler_WithGpuStats(t *testing.T) {
m, err := New(config.PerformanceConfig{}, newTestLogger())
require.NoError(t, err)
m.gpuRing.Push([]GpuStat{{ID: 0, Name: "TestGPU", UUID: "uuid-0", TempC: 65, Timestamp: time.Now()}})
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
rec := httptest.NewRecorder()
m.MetricsHandler()(rec, req)
body := rec.Body.String()
assert.Contains(t, body, "llamaswap_gpu_temperature_celsius")
assert.Contains(t, body, `name="TestGPU"`)
}
+40
View File
@@ -0,0 +1,40 @@
package perf
import "time"
type GpuStat struct {
Timestamp time.Time `json:"timestamp"`
ID int `json:"id"`
Name string `json:"name"`
UUID string `json:"uuid"`
TempC int `json:"temp_c"`
VramTempC int `json:"vram_temp_c"`
GpuUtilPct float64 `json:"gpu_util_pct"`
MemUtilPct float64 `json:"mem_util_pct"`
MemUsedMB int `json:"mem_used_mb"`
MemTotalMB int `json:"mem_total_mb"`
FanSpeedPct float64 `json:"fan_speed_pct"`
PowerDrawW float64 `json:"power_draw_w"`
}
type NetIOStat struct {
Name string `json:"name"`
BytesRecv uint64 `json:"bytes_recv"`
BytesSent uint64 `json:"bytes_sent"`
}
type SysStat struct {
Timestamp time.Time `json:"timestamp"`
CpuUtilPerCore []float64 `json:"cpu_util_per_core"`
MemTotalMB int `json:"mem_total_mb"`
MemUsedMB int `json:"mem_used_mb"`
MemFreeMB int `json:"mem_free_mb"`
SwapTotalMB int `json:"swap_total_mb"`
SwapUsedMB int `json:"swap_used_mb"`
LoadAvg1 float64 `json:"load_avg_1"`
LoadAvg5 float64 `json:"load_avg_5"`
LoadAvg15 float64 `json:"load_avg_15"`
NetIO []NetIOStat `json:"net_io"`
}
+49
View File
@@ -0,0 +1,49 @@
package process
import (
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
var simpleResponderPath string
func skipIfNoSimpleResponder(t *testing.T) {
t.Helper()
if _, err := os.Stat(simpleResponderPath); os.IsNotExist(err) {
t.Skipf("simple-responder not found at %s, run `make simple-responder`", simpleResponderPath)
}
}
func TestMain(m *testing.M) {
goos := runtime.GOOS
goarch := runtime.GOARCH
if goos == "windows" {
simpleResponderPath = filepath.Join("..", "..", "build", "simple-responder.exe")
} else {
simpleResponderPath = filepath.Join("..", "..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
m.Run()
}
func getFreePort(t *testing.T) int {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("getFreePort: %v", err)
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port
}
func simpleResponderCmd(t *testing.T, args ...string) (string, int) {
port := getFreePort(t)
cmdPath := filepath.ToSlash(simpleResponderPath)
base := []string{cmdPath, fmt.Sprintf("-port %d", port)}
base = append(base, args...)
return strings.Join(base, " "), port
}
+49
View File
@@ -0,0 +1,49 @@
package process
import (
"context"
"net/http"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateStarting ProcessState = ProcessState("starting")
StateReady ProcessState = ProcessState("ready")
StateStopping ProcessState = ProcessState("stopping")
// process is shutdown and will not be restarted
StateShutdown ProcessState = ProcessState("shutdown")
)
type Process interface {
// Run starts the process blocks until the process is terminated.
// The timeout parameter controls how long to wait for the process to get
// to a ready state to process traffic
Run(timeout time.Duration) error
// WaitReady blocks until the process is ready to serve requests
// or the context is cancelled. It returns nil when the process is ready
WaitReady(context.Context) error
// Stop blocks until the process has terminated. It returns nil when
// the process terminated as expected (exit 0)
Stop(timeout time.Duration) error
// State returns the current state of the process
// Note: this is a snapshot of the state at the time of the call
// and may change at any time after the call returns.
State() ProcessState
// ServeHTTP forwards requests to the underlying process
// Calling it when the process is not ready will result in a
// 503 response with a body indicating it is a llama-swap-error
ServeHTTP(http.ResponseWriter, *http.Request)
// Logger returns the monitor that captures this process's stdout/stderr.
Logger() *logmon.Monitor
}
+684
View File
@@ -0,0 +1,684 @@
package process
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"os/exec"
"strings"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/shared"
)
var ErrStartAborted = fmt.Errorf("aborted")
// cmdWaitDelay is the upper bound the runtime will wait for child I/O to
// drain after the process exits before force-closing the stdout/stderr
// pipes. Required so that cmd.Wait() returns even when a forked grandchild
// inherits and holds the pipes open (e.g. a shell wrapper that backgrounds
// the real binary). killProcess sends the stop signal directly (not via the
// cmd context), so this delay is measured from process exit rather than from
// the stop request, and stays independent of the caller's graceful timeout.
const cmdWaitDelay = 10 * time.Second
// parentCancelGraceTimeout is the graceful timeout used when the process is
// torn down because parentCtx was cancelled (final router teardown or app
// shutdown). In the normal flow the process has already been stopped via
// Stop() by this point, so killProcess is a no-op kill; the short grace just
// bounds the rare case where a process is still alive when its context is cut.
const parentCancelGraceTimeout = time.Second
type runReq struct {
timeout time.Duration
respond chan error
}
type stopReq struct {
timeout time.Duration
respond chan error
}
type waitReadyReq struct {
respond chan error
}
type startResult struct {
cmd *exec.Cmd
cmdDone chan struct{}
cancel context.CancelFunc
handlerFn http.HandlerFunc
err error
}
type ProcessCommand struct {
id string
config config.ModelConfig
parentCtx context.Context
processLogger *logmon.Monitor
proxyLogger *logmon.Monitor
// waitDelay is assigned to cmd.WaitDelay when starting the upstream
// process. Defaults to cmdWaitDelay; tests override it to keep the
// pipe-close backstop from dominating their runtime.
waitDelay time.Duration
runCh chan runReq
stopCh chan stopReq
waitReadyCh chan waitReadyReq
// current ProcessState. Written only by run(); read by State() via atomic load.
state atomic.Value
// stores the active reverse-proxy handler when the process is running.
// Written only by run(); read by ServeHTTP via atomic load.
handler atomic.Pointer[http.HandlerFunc]
lastUse atomic.Int64 // unix nano timestamp of last ServeHTTP completion
inflight atomic.Int64 // current in-flight ServeHTTP calls
}
var _ Process = (*ProcessCommand)(nil)
func New(
parentCtx context.Context,
id string,
conf config.ModelConfig,
processLogger *logmon.Monitor,
proxyLogger *logmon.Monitor,
) (*ProcessCommand, error) {
p := &ProcessCommand{
id: id,
config: conf,
parentCtx: parentCtx,
processLogger: processLogger,
proxyLogger: proxyLogger,
runCh: make(chan runReq),
stopCh: make(chan stopReq),
waitReadyCh: make(chan waitReadyReq),
waitDelay: cmdWaitDelay,
}
p.state.Store(StateStopped)
go p.run()
return p, nil
}
func (p *ProcessCommand) Logger() *logmon.Monitor { return p.processLogger }
// run is the single-writer goroutine that owns all mutable lifecycle state
// (current ProcessState, the running *exec.Cmd, the active reverse-proxy
// handler, and the list of WaitReady subscribers). Every public method
// (Run / Stop / State / WaitReady) is a thin client that sends a request on
// one of the channels below and waits for a response — this funnels concurrent
// callers through a single serialization point so the state machine never
// observes a race.
func (p *ProcessCommand) run() {
// Mutable state — only read/written from this goroutine. ServeHTTP reads
// p.handler concurrently, which is why handler is an atomic.Pointer.
// p.state mirrors `state` so State() can observe transitions; setState
// writes both.
state := StateStopped
setState := func(s ProcessState) {
old := state
state = s
p.state.Store(s)
if old != s {
event.Emit(shared.ProcessStateChangeEvent{
ProcessName: p.id,
OldState: string(old),
NewState: string(s),
})
}
}
var (
cmd *exec.Cmd
cmdDone <-chan struct{}
cmdCancel context.CancelFunc
readyWaiters []waitReadyReq
// runResp parks the in-flight Run caller's response channel. The
// interface contract is that Run blocks until the process is
// terminated, so we hold this until Stop, parentCtx, or an
// upstream exit unblocks it via respondRun.
runResp chan<- error
)
// notifyWaiters wakes every blocked WaitReady caller with the given result.
// Used on transitions out of StateStarting (ready, failed, aborted, or
// shutdown) — anything that resolves the "is it ready yet?" question.
notifyWaiters := func(err error) {
for _, w := range readyWaiters {
select {
case w.respond <- err:
default:
}
}
readyWaiters = nil
}
// respondRun delivers the final Run result, if a Run caller is parked.
respondRun := func(err error) {
if runResp != nil {
runResp <- err
runResp = nil
}
}
for {
select {
// Shutdown: parent context cancelled. Tear down any running process,
// wake any pending WaitReady callers with an error, then exit the
// goroutine permanently. Subsequent public-method calls will fail
// because parentCtx.Done() unblocks their send-side selects.
case <-p.parentCtx.Done():
// Mark shutdown before killProcess so concurrent State() readers
// stop treating this process as ready while the (possibly slow)
// teardown is in progress.
setState(StateShutdown)
if cmd != nil {
p.handler.Store(nil)
p.killProcess(cmd, cmdCancel, cmdDone, parentCancelGraceTimeout)
cmd = nil
cmdDone = nil
cmdCancel = nil
}
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
respondRun(fmt.Errorf("[%s] shutdown", p.id))
return
// Upstream exited on its own (not via Stop). Drop handler state,
// transition to Stopped, and unblock the parked Run caller.
// cmdDone is nil while no process is running, so this case is
// dormant outside of StateReady.
case <-cmdDone:
if cmdCancel != nil {
cmdCancel()
}
cmd = nil
cmdDone = nil
cmdCancel = nil
p.handler.Store(nil)
setState(StateStopped)
respondRun(fmt.Errorf("[%s] upstream exited unexpectedly", p.id))
// WaitReady: if we're already in a terminal-for-this-question state,
// respond immediately; otherwise queue the caller and let a future
// state transition wake them via notifyWaiters.
case req := <-p.waitReadyCh:
switch state {
case StateReady:
req.respond <- nil
case StateShutdown:
req.respond <- fmt.Errorf("[%s] shutdown", p.id)
default:
readyWaiters = append(readyWaiters, req)
}
// Run: start the upstream process. Only valid from StateStopped.
// doStart can take a long time (health-check polling), so it runs in
// a separate goroutine and we wait on resultCh. While waiting we also
// listen for an incoming Stop — that's how callers cancel an in-flight
// start.
case req := <-p.runCh:
if state != StateStopped {
req.respond <- fmt.Errorf("[%s] could not be started in %s state", p.id, state)
continue
}
setState(StateStarting)
startCtx, cancelStart := context.WithCancel(context.Background())
resultCh := make(chan startResult, 1)
go func() {
resultCh <- p.doStart(startCtx, req.timeout)
}()
// pendingStop holds a Stop request that arrived mid-start, so we
// can respond to it AFTER we've finished tearing the start down.
var pendingStop *stopReq
select {
// doStart finished on its own — either successfully (latch
// cmd/handler and move to Ready) or with an error (back to
// Stopped). Either way wake WaitReady subscribers and reply
// to the Run caller.
case res := <-resultCh:
if res.err == nil {
cmd = res.cmd
cmdDone = res.cmdDone
cmdCancel = res.cancel
fn := res.handlerFn
p.handler.Store(&fn)
setState(StateReady)
notifyWaiters(nil)
// Park the Run response — Run blocks until the process
// terminates, so we only fire this when Stop, parentCtx,
// or the upstream exit takes the process down.
runResp = req.respond
// Start TTL goroutine if configured — self-terminates
// when state leaves StateReady.
if p.config.UnloadAfter > 0 {
ttlDuration := time.Duration(p.config.UnloadAfter) * time.Second
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
if p.State() != StateReady {
return
}
if p.inflight.Load() != 0 {
continue
}
if time.Since(time.Unix(0, p.lastUse.Load())) > ttlDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.id, p.config.UnloadAfter)
p.Stop(10 * time.Second)
return
}
}
}()
}
} else {
setState(StateStopped)
notifyWaiters(res.err)
req.respond <- res.err
}
// Stop arrived while doStart was still running. Cancel the
// start context to abort it, then wait for doStart to return.
// If doStart had already crossed the finish line before
// cancellation took effect, it returns a live cmd that we
// must kill ourselves. The Run caller gets ErrAbort; the Stop
// caller is parked in pendingStop and answered below.
case stop := <-p.stopCh:
cancelStart()
res := <-resultCh
if res.cmd != nil {
p.killProcess(res.cmd, res.cancel, res.cmdDone, stop.timeout)
}
setState(StateStopped)
notifyWaiters(ErrStartAborted)
req.respond <- ErrStartAborted
pendingStop = &stop
// Parent context cancelled (e.g. config reload) while doStart
// was still running. Stop() returns early when parentCtx is
// done and never sends on stopCh, so we must handle shutdown
// here to avoid leaving doStart running indefinitely.
case <-p.parentCtx.Done():
cancelStart()
// Mark shutdown before tearing the process down: killProcess
// may block (e.g. taskkill on Windows is slow to spawn), and
// callers observing State() should see StateShutdown promptly
// rather than a stale StateStarting.
setState(StateShutdown)
res := <-resultCh
if res.cmd != nil {
p.killProcess(res.cmd, res.cancel, res.cmdDone, parentCancelGraceTimeout)
}
notifyWaiters(fmt.Errorf("[%s] shutdown", p.id))
respondRun(fmt.Errorf("[%s] shutdown", p.id))
return
}
// cancelStart is idempotent; calling it again here ensures the
// context is released even on the success path (govet leak check).
cancelStart()
if pendingStop != nil {
pendingStop.respond <- nil
}
// Stop: tear down a running process.
case stop := <-p.stopCh:
if cmd != nil {
setState(StateStopping)
p.killProcess(cmd, cmdCancel, cmdDone, stop.timeout)
cmd = nil
cmdDone = nil
cmdCancel = nil
p.handler.Store(nil)
}
// Stop is a no-op (and not an error) when already Stopped — this
// is what makes it idempotent for callers that don't track state.
setState(StateStopped)
respondRun(nil)
stop.respond <- nil
}
}
}
func (p *ProcessCommand) doStart(startCtx context.Context, healthCheckTimeout time.Duration) startResult {
if p.config.Proxy == "" {
return startResult{err: fmt.Errorf("upstream proxy missing")}
}
args, err := p.config.SanitizedCommand()
if err != nil {
return startResult{err: fmt.Errorf("unable to get sanitized command: %w", err)}
}
proxyURL, err := url.Parse(p.config.Proxy)
if err != nil {
return startResult{err: fmt.Errorf("invalid proxy URL %q: %w", p.config.Proxy, err)}
}
reverseProxy := httputil.NewSingleHostReverseProxy(proxyURL)
reverseProxy.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(p.config.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(p.config.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(p.config.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(p.config.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(p.config.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(p.config.Timeouts.IdleConn) * time.Second,
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
// httputil.ReverseProxy panics with http.ErrAbortHandler when the upstream
// disconnects after response headers have been sent. Recover here so the
// streaming termination is treated as a normal client/upstream disconnect.
// see: https://github.com/golang/go/issues/23643
handlerFn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
if rec == http.ErrAbortHandler {
p.proxyLogger.Infof("<%s> recovered from upstream disconnection during streaming", p.id)
} else {
p.proxyLogger.Warnf("<%s> recovered from panic: %v", p.id, rec)
}
}
}()
reverseProxy.ServeHTTP(w, r)
})
// cmdCtx + cmd.Cancel are wired as a safety net: if the context is ever
// cancelled while the process is alive, cmd.Cancel sends SIGTERM / CmdStop
// and the runtime escalates to SIGKILL after cmd.WaitDelay. In the normal
// teardown path killProcess sends the stop signal directly instead, so
// cmd.WaitDelay only acts as the inherited-pipe backstop measured from
// process exit (see killProcess).
cmdCtx, cmdCancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(cmdCtx, args[0], args[1:]...)
cmd.Stderr = p.processLogger
cmd.Stdout = p.processLogger
cmd.Env = append(cmd.Environ(), p.config.Env...)
cmd.Cancel = func() error { return p.sendStopSignal(cmd) }
cmd.WaitDelay = p.waitDelay
setProcAttributes(cmd)
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.id, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
cmdDone := make(chan struct{})
if err := cmd.Start(); err != nil {
cmdCancel()
return startResult{err: fmt.Errorf("failed to start command '%s': %w", strings.Join(args, " "), err)}
}
go func() {
waitErr := cmd.Wait()
switch st := p.State(); {
case waitErr == nil:
p.proxyLogger.Debugf("<%s> process exited cleanly", p.id)
case st == StateStopping || st == StateShutdown:
// Expected: we force-terminated the process. A forced kill exits
// the child with a non-zero code (e.g. taskkill /f on Windows
// yields exit status 1), so this is not an error.
p.proxyLogger.Debugf("<%s> process stopped by llama-swap: %v", p.id, waitErr)
default:
if exitErr, ok := waitErr.(*exec.ExitError); ok {
p.proxyLogger.Debugf("<%s> process exited: code=%d, err=%v", p.id, exitErr.ExitCode(), waitErr)
} else {
p.proxyLogger.Debugf("<%s> process exited with error: %v", p.id, waitErr)
}
}
close(cmdDone)
}()
abort := func(err error) startResult {
p.killProcess(cmd, cmdCancel, cmdDone, 5*time.Second)
return startResult{err: err}
}
prematureExit := func() startResult {
cmdCancel()
return startResult{err: fmt.Errorf("upstream command exited prematurely")}
}
if startCtx.Err() != nil {
return abort(ErrStartAborted)
}
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
if checkEndpoint == "none" {
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
}
// Wait 250ms for the command to start up before health checking
select {
case <-startCtx.Done():
return abort(ErrStartAborted)
case <-time.After(250 * time.Millisecond):
}
deadline := time.Now().Add(healthCheckTimeout)
for {
select {
case <-startCtx.Done():
return abort(ErrStartAborted)
case <-cmdDone:
return prematureExit()
default:
}
if time.Now().After(deadline) {
return abort(fmt.Errorf("health check timed out after %v", healthCheckTimeout))
}
req, _ := http.NewRequestWithContext(startCtx, "GET", p.config.CheckEndpoint, nil)
rr := httptest.NewRecorder()
reverseProxy.ServeHTTP(rr, req)
resp := rr.Result()
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
p.proxyLogger.Infof("<%s> Health check passed on %s%s", p.id, p.config.Proxy, p.config.CheckEndpoint)
break
} else if startCtx.Err() != nil {
return abort(ErrStartAborted)
}
select {
case <-startCtx.Done():
return abort(ErrStartAborted)
case <-cmdDone:
return prematureExit()
case <-time.After(time.Second):
}
}
return startResult{cmd: cmd, cmdDone: cmdDone, cancel: cmdCancel, handlerFn: handlerFn}
}
// sendStopSignal runs the configured CmdStop (if any) or sends SIGTERM to
// the upstream process. Wired up as cmd.Cancel so it fires whenever the
// cmd's context is cancelled.
func (p *ProcessCommand) sendStopSignal(cmd *exec.Cmd) error {
if cmd == nil || cmd.Process == nil {
p.processLogger.Debugf("<%s> sendStopSignal() called with nil cmd or process, nothing to stop", p.id)
return nil
}
pid := cmd.Process.Pid
if p.config.CmdStop != "" {
p.processLogger.Debugf("<%s> sendStopSignal() using CmdStop %q for pid %d", p.id, p.config.CmdStop, pid)
stopArgs, err := config.SanitizeCommand(
strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", pid)),
)
if err == nil {
p.processLogger.Debugf("<%s> sendStopSignal() running stop command: %s", p.id, strings.Join(stopArgs, " "))
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Env = cmd.Env
setProcAttributes(stopCmd)
runErr := stopCmd.Run()
if runErr != nil {
p.processLogger.Errorf("<%s> sendStopSignal() stop command failed: %v", p.id, runErr)
} else {
p.processLogger.Debugf("<%s> sendStopSignal() stop command completed for pid %d", p.id, pid)
}
return runErr
}
// fall through to SIGTERM if sanitize failed
p.processLogger.Errorf("<%s> sendStopSignal() failed to sanitize CmdStop %q: %v, falling back to terminateProcessTree", p.id, p.config.CmdStop, err)
}
// On Unix this SIGTERMs the whole process group so a forked grandchild
// (e.g. a shell wrapper that backgrounds the real binary) is taken down
// with the parent rather than orphaned.
p.processLogger.Debugf("<%s> sendStopSignal() no CmdStop configured, calling terminateProcessTree for pid %d", p.id, pid)
termErr := terminateProcessTree(cmd)
if termErr != nil {
p.processLogger.Errorf("<%s> sendStopSignal() terminateProcessTree failed for pid %d: %v", p.id, pid, termErr)
}
return termErr
}
// killProcess terminates the upstream process. The flow:
//
// 1. Send the graceful stop signal (CmdStop / SIGTERM) directly — NOT by
// cancelling cmdCtx. Cancelling the context would start cmd.WaitDelay
// immediately, which force-kills the process WaitDelay after the signal
// and would silently cap gracefulTimeout at WaitDelay whenever
// gracefulTimeout is the longer of the two.
// 2. We wait up to gracefulTimeout for the process to exit on its own.
// 3. If still alive, we SIGKILL the process group directly (Unix) so any
// forked descendant is force-terminated alongside the parent.
// 4. We wait on cmdDone. cmd.WaitDelay (set when the cmd was built) is the
// critical backstop here: once the process exits, if a forked grandchild
// inherited the stdout/stderr pipes and is still holding them, the runtime
// force-closes the pipes WaitDelay after the exit and cmd.Wait() unblocks.
// Because we never cancelled the context, that WaitDelay timer measures
// from process exit (see os/exec awaitGoroutines), not from this call.
// Without WaitDelay this select would hang forever (the v219 bug).
//
// cancel() is still invoked (deferred) to release the context, but only after
// the process has exited and os/exec's ctx watcher has already torn down, so it
// never re-fires cmd.Cancel.
func (p *ProcessCommand) killProcess(cmd *exec.Cmd, cancel context.CancelFunc, cmdDone <-chan struct{}, gracefulTimeout time.Duration) {
if cancel == nil {
return
}
defer cancel()
// Deliver CmdStop / SIGTERM in a goroutine so a slow or hanging CmdStop
// cannot block the run() goroutine; the gracefulTimeout + Process.Kill
// path below still guarantees teardown.
if cmd != nil {
go func() {
p.proxyLogger.Debugf("[%s] sending stop signal with timeout %v", p.id, gracefulTimeout)
if err := p.sendStopSignal(cmd); err != nil {
p.proxyLogger.Warnf("[%s] stop signal failed: %v", p.id, err)
}
}()
}
timer := time.NewTimer(gracefulTimeout)
defer timer.Stop()
select {
case <-cmdDone:
return
case <-timer.C:
}
if cmd != nil {
// SIGKILL the whole process group on Unix so any descendant that
// ignored or outlived the graceful signal is force-terminated too.
_ = killProcessTree(cmd)
}
<-cmdDone
}
func (p *ProcessCommand) ID() string {
return p.id
}
func (p *ProcessCommand) Run(timeout time.Duration) error {
req := runReq{
timeout: timeout,
respond: make(chan error, 1),
}
select {
case p.runCh <- req:
case <-p.parentCtx.Done():
return fmt.Errorf("[%s] shutdown", p.id)
}
select {
case err := <-req.respond:
return err
case <-p.parentCtx.Done():
return fmt.Errorf("[%s] shutdown", p.id)
}
}
func (p *ProcessCommand) WaitReady(ctx context.Context) error {
req := waitReadyReq{respond: make(chan error, 1)}
select {
case p.waitReadyCh <- req:
case <-ctx.Done():
return ctx.Err()
case <-p.parentCtx.Done():
return fmt.Errorf("[%s] shutdown", p.id)
}
select {
case err := <-req.respond:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (p *ProcessCommand) Stop(timeout time.Duration) error {
req := stopReq{
timeout: timeout,
respond: make(chan error, 1),
}
select {
case p.stopCh <- req:
case <-p.parentCtx.Done():
return fmt.Errorf("[%s] shutdown", p.id)
}
return <-req.respond
}
func (p *ProcessCommand) State() ProcessState {
if s, ok := p.state.Load().(ProcessState); ok {
return s
}
return StateStopped
}
func (p *ProcessCommand) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fn := p.handler.Load()
if fn == nil {
http.Error(w, fmt.Sprintf("llama-swap-error: [%s] process is not ready", p.id), http.StatusServiceUnavailable)
return
}
p.inflight.Add(1)
defer func() {
p.lastUse.Store(time.Now().UnixNano())
p.inflight.Add(-1)
}()
(*fn)(w, r)
}
@@ -0,0 +1,262 @@
//go:build !windows
package process
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
)
// TestProcessCommand_StopForkingWrapper is a regression for the bug reported
// against v219 where Stop would hang indefinitely when the upstream command
// is a shell wrapper that forks the real binary (e.g. `#!/bin/bash` then
// `"$@"`). After SIGTERM the wrapper dies but the grandchild inherits the
// stdout/stderr pipes; cmd.Wait() blocks waiting for the pipe-copy goroutine
// to drain EOF, which never happens while the grandchild holds the fds.
//
// The fix is cmd.WaitDelay (combined with exec.CommandContext + cmd.Cancel),
// which causes the runtime to force-close the pipes after the delay so
// cmd.Wait() — and therefore Stop — returns.
func TestProcessCommand_StopForkingWrapper(t *testing.T) {
skipIfNoSimpleResponder(t)
port := getFreePort(t)
dir := t.TempDir()
pidFile := filepath.Join(dir, "child.pid")
// Wrapper script: backgrounds the child (which inherits stdout/stderr),
// records its PID for cleanup, then waits. When SIGTERM hits bash it
// dies without forwarding the signal; the grandchild keeps running and
// keeps the inherited pipe fds open. This is the scenario reported in
// the v219 regression.
wrapper := filepath.Join(dir, "wrapper.sh")
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
simpleResponderPath, port, pidFile)
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
t.Fatalf("WriteFile: %v", err)
}
t.Cleanup(func() { killChildFromPidFile(pidFile) })
p := newProcessCommand(t, config.ModelConfig{
Cmd: wrapper,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
// Shrink the pipe-close backstop so the test doesn't sit at the
// production default (10s). Must be set before Run() so doStart picks
// it up when building the cmd.
const testWaitDelay = 250 * time.Millisecond
p.waitDelay = testWaitDelay
runErr := runAsync(t, p)
// Stop must return within a bounded time even though the grandchild
// is still holding the pipe open. Budget is generous on top of
// testWaitDelay to absorb scheduling jitter on slow CI runners; the
// pre-fix behaviour was an unbounded hang, so any reasonable cap
// distinguishes pass from fail.
stopReturned := make(chan error, 1)
stopStart := time.Now()
go func() { stopReturned <- p.Stop(testStopTimeout) }()
const stopBudget = testWaitDelay + 2*time.Second
select {
case err := <-stopReturned:
if err != nil {
t.Fatalf("Stop: %v", err)
}
t.Logf("Stop returned in %v", time.Since(stopStart))
case <-time.After(stopBudget):
t.Fatalf("Stop did not return within %v — cmd.Wait() likely hung on inherited pipe", stopBudget)
}
if got := p.State(); got != StateStopped {
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Errorf("Run did not return after Stop")
}
}
// TestProcessCommand_StopHonorsGracefulTimeout is a regression for the bug
// where cmd.WaitDelay capped the graceful shutdown window. killProcess used to
// cancel the cmd context to deliver SIGTERM, which starts cmd.WaitDelay
// immediately; a process whose SIGTERM handler needs longer than WaitDelay to
// finish was force-killed early even though Stop was given a much longer
// timeout. The fix sends the signal directly so WaitDelay measures from process
// exit (its inherited-pipe backstop role), leaving the graceful window to the
// caller's Stop timeout.
func TestProcessCommand_StopHonorsGracefulTimeout(t *testing.T) {
dir := t.TempDir()
marker := filepath.Join(dir, "graceful.done")
ready := filepath.Join(dir, "trap.ready")
// On SIGTERM, sleep past the (short) WaitDelay, then write the marker and
// exit cleanly. If WaitDelay still drove the kill, bash would be SIGKILLed
// mid-handler and the marker would never be written. The ready file is
// written only after the trap is installed so the test does not race
// SIGTERM ahead of it (CheckEndpoint:none marks ready before bash runs).
script := filepath.Join(dir, "graceful.sh")
body := fmt.Sprintf(
"#!/bin/bash\ncleanup() { sleep 0.6; echo done > %q; exit 0; }\ntrap cleanup SIGTERM\necho ready > %q\nwhile true; do sleep 0.1; done\n",
marker, ready,
)
if err := os.WriteFile(script, []byte(body), 0o755); err != nil {
t.Fatalf("WriteFile: %v", err)
}
p := newProcessCommand(t, config.ModelConfig{
Cmd: script,
Proxy: "http://127.0.0.1:1", // unused: health check disabled
CheckEndpoint: "none",
})
// WaitDelay shorter than the handler's 0.6s sleep, and far shorter than the
// Stop timeout below — this is the window the old code mis-killed in.
p.waitDelay = 200 * time.Millisecond
runErr := runAsync(t, p)
// Wait until the trap is installed before stopping.
trapDeadline := time.Now().Add(2 * time.Second)
for {
if _, err := os.Stat(ready); err == nil {
break
}
if time.Now().After(trapDeadline) {
t.Fatalf("script did not install SIGTERM trap in time")
}
time.Sleep(10 * time.Millisecond)
}
stopStart := time.Now()
if err := p.Stop(5 * time.Second); err != nil {
t.Fatalf("Stop: %v", err)
}
elapsed := time.Since(stopStart)
// The handler must have run to completion (marker written) rather than
// being force-killed at waitDelay.
if _, err := os.Stat(marker); err != nil {
t.Fatalf("graceful handler did not complete (marker missing): %v", err)
}
// And Stop must have waited for the handler (>~0.6s), not returned at the
// 200ms waitDelay.
if elapsed < 500*time.Millisecond {
t.Fatalf("Stop returned in %v — process was killed before its graceful handler finished", elapsed)
}
if got := p.State(); got != StateStopped {
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Errorf("Run did not return after Stop")
}
}
// TestProcessCommand_StopReapsForkedGrandchild verifies that stopping a forking
// wrapper takes down the backgrounded grandchild too, rather than leaving it as
// an orphan. The fix is Setpgid (runtime_unix.go): the wrapper leads its own
// process group, so the stop signal is delivered to the whole group via the
// negative PID and reaches the grandchild the wrapper never reaped.
func TestProcessCommand_StopReapsForkedGrandchild(t *testing.T) {
skipIfNoSimpleResponder(t)
port := getFreePort(t)
dir := t.TempDir()
pidFile := filepath.Join(dir, "child.pid")
wrapper := filepath.Join(dir, "wrapper.sh")
script := fmt.Sprintf("#!/bin/bash\n%q -port %d -silent &\necho $! > %q\nwait\n",
simpleResponderPath, port, pidFile)
if err := os.WriteFile(wrapper, []byte(script), 0o755); err != nil {
t.Fatalf("WriteFile: %v", err)
}
t.Cleanup(func() { killChildFromPidFile(pidFile) })
p := newProcessCommand(t, config.ModelConfig{
Cmd: wrapper,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
runErr := runAsync(t, p)
// Read the grandchild PID the wrapper recorded.
var childPID int
deadline := time.Now().Add(2 * time.Second)
for {
data, err := os.ReadFile(pidFile)
if err == nil {
if pid, perr := strconv.Atoi(strings.TrimSpace(string(data))); perr == nil && pid > 0 {
childPID = pid
break
}
}
if time.Now().After(deadline) {
t.Fatalf("wrapper did not record grandchild PID")
}
time.Sleep(10 * time.Millisecond)
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop: %v", err)
}
// After Stop the grandchild must be gone. Signal 0 probes liveness without
// actually sending a signal; give it a brief window to exit after the
// group SIGTERM.
proc, err := os.FindProcess(childPID)
if err != nil {
t.Fatalf("FindProcess: %v", err)
}
gone := false
for i := 0; i < 100; i++ {
if err := proc.Signal(syscall.Signal(0)); err != nil {
gone = true
break
}
time.Sleep(10 * time.Millisecond)
}
if !gone {
t.Errorf("grandchild PID %d still alive after Stop — process group was not reaped", childPID)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Errorf("Run did not return after Stop")
}
}
// killChildFromPidFile reads a PID written by the wrapper script and SIGKILLs
// it so leaked orphans don't accumulate between test runs. Best-effort.
func killChildFromPidFile(pidFile string) {
data, err := os.ReadFile(pidFile)
if err != nil {
return
}
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
if err != nil || pid <= 0 {
return
}
proc, err := os.FindProcess(pid)
if err != nil {
return
}
_ = proc.Kill()
}
+646
View File
@@ -0,0 +1,646 @@
package process
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
const (
testStartTimeout = 3 * time.Second
testStopTimeout = 2 * time.Second
testReturnTimeout = 1 * time.Second
testPollInterval = 20 * time.Millisecond
testLogPollInterval = 10 * time.Millisecond
)
func newProcessCommand(t *testing.T, conf config.ModelConfig) *ProcessCommand {
t.Helper()
logger := logmon.NewWriter(io.Discard)
p, err := New(context.Background(), t.Name(), conf, logger, logger)
if err != nil {
t.Fatalf("New: %v", err)
}
return p
}
// runAsync starts Run in a goroutine and waits until the process is ready,
// matching the new interface contract where Run blocks until the process is
// terminated. Returns a channel that delivers Run's eventual error.
func runAsync(t *testing.T, p *ProcessCommand) <-chan error {
t.Helper()
ch := make(chan error, 1)
go func() { ch <- p.Run(testStartTimeout) }()
ctx, cancel := context.WithTimeout(context.Background(), testStartTimeout)
defer cancel()
if err := p.WaitReady(ctx); err != nil {
t.Fatalf("WaitReady: %v", err)
}
return ch
}
func TestProcessCommand_StartStop(t *testing.T) {
skipIfNoSimpleResponder(t)
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
t.Cleanup(func() { p.Stop(testStopTimeout) })
req := httptest.NewRequest("GET", "/test", nil)
// before start: no handler
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusServiceUnavailable {
t.Errorf("before start: expected 503, got %d", rr.Code)
}
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
t.Errorf("before start: expected body to contain %q, got %q", "llama-swap-error", body)
}
runErr := runAsync(t, p)
if got := p.State(); got != StateReady {
t.Errorf("after Run: expected state %s, got %s", StateReady, got)
}
rr = httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("after Run: expected 200, got %d", rr.Code)
}
if body := rr.Body.String(); body != "hello" {
t.Errorf("expected body %q, got %q", "hello", body)
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
if got := p.State(); got != StateStopped {
t.Errorf("after Stop: expected state %s, got %s", StateStopped, got)
}
select {
case err := <-runErr:
if err != nil {
t.Errorf("Run() after Stop: expected nil, got %v", err)
}
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
// after stop: handler cleared
rr = httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusServiceUnavailable {
t.Errorf("after stop: expected 503, got %d", rr.Code)
}
if body := rr.Body.String(); !strings.Contains(body, "llama-swap-error") {
t.Errorf("after stop: expected body to contain %q, got %q", "llama-swap-error", body)
}
}
func TestProcessCommand_Run_Idempotent(t *testing.T) {
skipIfNoSimpleResponder(t)
cmd, port := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
t.Cleanup(func() { p.Stop(testStopTimeout) })
runErr := runAsync(t, p)
if err := p.Run(testStartTimeout); err == nil {
t.Error("second Run() while running: expected error, got nil")
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
}
func TestProcessCommand_Stop_Idempotent(t *testing.T) {
skipIfNoSimpleResponder(t)
cmd, port := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() before Run(): %v", err)
}
runErr := runAsync(t, p)
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("first Stop() error: %v", err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("second Stop() error: %v", err)
}
}
// TestProcessCommand_StopCancelsRun verifies that a Stop sent while Run is
// executing its health-check loop returns ErrAbort to the Run caller.
//
// A blocking mock HTTP server is used as the proxy so the test can deterministically
// know when doStart is inside the health-check loop before issuing Stop.
func TestProcessCommand_StopCancelsRun(t *testing.T) {
skipIfNoSimpleResponder(t)
healthCheckStarted := make(chan struct{}, 1)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Signal that a health check is in-flight, then block until the client
// cancels (which happens when Stop cancels the start context).
select {
case healthCheckStarted <- struct{}{}:
default:
}
<-r.Context().Done()
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
}))
defer mock.Close()
// simple-responder is the real process; health checks go to the blocking mock.
cmd, _ := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 30,
})
runErrCh := make(chan error, 1)
go func() {
runErrCh <- p.Run(testStartTimeout)
}()
// Block until doStart is actually performing a health check, guaranteeing
// that Run is in-flight when Stop is called.
<-healthCheckStarted
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
if err := <-runErrCh; !errors.Is(err, ErrStartAborted) {
t.Errorf("expected ErrStartAborted from Run, got %v", err)
}
}
// TestProcessCommand_ParentCtxCancelDuringStart verifies that cancelling the
// parent context while doStart is health-checking causes the process to
// transition to StateShutdown promptly, not wait for the health-check timeout.
//
// This is the config-reload race: Stop() returns early when parentCtx is
// already done and never writes to stopCh, so without a parentCtx.Done()
// case in the inner select, the process would keep loading indefinitely.
func TestProcessCommand_ParentCtxCancelDuringStart(t *testing.T) {
skipIfNoSimpleResponder(t)
healthCheckStarted := make(chan struct{}, 1)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case healthCheckStarted <- struct{}{}:
default:
}
<-r.Context().Done()
http.Error(w, "mock cancelled", http.StatusServiceUnavailable)
}))
defer mock.Close()
parentCtx, cancelParent := context.WithCancel(context.Background())
logger := logmon.NewWriter(io.Discard)
cmd, _ := simpleResponderCmd(t, "-silent")
p, err := New(parentCtx, t.Name(), config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 60,
}, logger, logger)
if err != nil {
t.Fatalf("New: %v", err)
}
runErrCh := make(chan error, 1)
go func() { runErrCh <- p.Run(60 * time.Second) }()
<-healthCheckStarted
// Cancel parent context to simulate a config reload tearing down the old server.
cancelParent()
select {
case err := <-runErrCh:
if !strings.Contains(err.Error(), "shutdown") {
t.Errorf("Run error = %v, want shutdown error", err)
}
case <-time.After(5 * time.Second):
t.Fatal("process did not shut down within 5s after parent context cancel during start")
}
// Run() may return before the run() goroutine writes StateShutdown;
// poll briefly to avoid a spurious race in the assertion.
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if p.State() == StateShutdown {
break
}
time.Sleep(testPollInterval)
}
if got := p.State(); got != StateShutdown {
t.Errorf("after cancel: expected StateShutdown, got %s", got)
}
}
// TestProcessCommand_RunStopCycle runs several sequential start/stop pairs on
// fresh processes to confirm they are reusable.
func TestProcessCommand_RunStopCycle(t *testing.T) {
skipIfNoSimpleResponder(t)
for i := range 3 {
cmd, port := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
runErr := runAsync(t, p)
req := httptest.NewRequest("GET", "/health", nil)
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("cycle %d: expected 200 from /health, got %d", i, rr.Code)
}
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("cycle %d Stop() error: %v", i, err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatalf("cycle %d: Run() did not return after Stop", i)
}
}
}
// TestProcessCommand_ReverseProxyPanicIsRecovered drives the full proxy path:
// the upstream responds healthy on /health (so Run completes), then on the
// actual proxied request it hijacks the connection and closes it mid-body.
// That upstream EOF makes httputil.ReverseProxy.copyResponse return an error,
// which panics with http.ErrAbortHandler — the wrapped handlerFn must recover
// and log the disconnect.
//
// Requests are issued through an httptest.NewServer wrapping the process so
// the panic actually fires (httputil only panics on copy errors when the
// request carries http.ServerContextKey, which a real server sets).
//
// see: https://github.com/golang/go/issues/23643
func TestProcessCommand_ReverseProxyPanicIsRecovered(t *testing.T) {
skipIfNoSimpleResponder(t)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/health" {
w.WriteHeader(http.StatusOK)
return
}
// Send a Content-Length that promises 100 bytes, deliver only a few,
// then slam the connection shut. The reverse proxy will see EOF
// before the body is fully copied and panic with ErrAbortHandler.
hj, ok := w.(http.Hijacker)
if !ok {
t.Errorf("upstream: hijack not supported")
return
}
conn, _, err := hj.Hijack()
if err != nil {
t.Errorf("upstream: hijack: %v", err)
return
}
_, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\nContent-Type: text/plain\r\n\r\npartial"))
_ = conn.Close()
}))
t.Cleanup(upstream.Close)
// Capture proxy log output so we can assert the recover message was
// emitted by handlerFn.
logBuf := &syncBuffer{}
proxyLogger := logmon.NewWriter(logBuf)
procLogger := logmon.NewWriter(io.Discard)
cmd, _ := simpleResponderCmd(t, "-silent")
p, err := New(context.Background(), t.Name(), config.ModelConfig{
Cmd: cmd,
Proxy: upstream.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
}, procLogger, proxyLogger)
if err != nil {
t.Fatalf("New: %v", err)
}
t.Cleanup(func() { p.Stop(testStopTimeout) })
_ = runAsync(t, p)
// Wrap p in an httptest server so requests get http.ServerContextKey
// automatically — that is what makes httputil.ReverseProxy raise the panic.
front := httptest.NewServer(p)
t.Cleanup(front.Close)
resp, err := http.Get(front.URL + "/disconnect")
if err == nil {
resp.Body.Close()
}
const want = "recovered from upstream disconnection"
deadline := time.Now().Add(testReturnTimeout)
for time.Now().Before(deadline) {
if strings.Contains(logBuf.String(), want) {
return
}
time.Sleep(testLogPollInterval)
}
t.Errorf("expected proxy log to contain %q; got:\n%s", want, logBuf.String())
}
// syncBuffer is a concurrent-safe bytes.Buffer for capturing logmon output.
type syncBuffer struct {
mu sync.Mutex
buf bytes.Buffer
}
func (b *syncBuffer) Write(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.Write(p)
}
func (b *syncBuffer) String() string {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.String()
}
// TestProcessCommand_TTL_StopsAfterIdle verifies that a process with a TTL
// automatically stops itself after the idle timeout has elapsed following its
// last request.
func TestProcessCommand_TTL_StopsAfterIdle(t *testing.T) {
skipIfNoSimpleResponder(t)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(mock.Close)
cmd, _ := simpleResponderCmd(t, "-silent")
cfg := config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
UnloadAfter: 1, // 1-second TTL
}
if runtime.GOOS == "windows" {
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
}
p := newProcessCommand(t, cfg)
runErr := runAsync(t, p)
defer func() {
if p.State() == StateReady {
p.Stop(testStopTimeout)
}
}()
if got := p.State(); got != StateReady {
t.Fatalf("expected StateReady, got %s", got)
}
// Make one request to prime the last-use timestamp.
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200 after request, got %d", rr.Code)
}
// Wait for the TTL goroutine to fire and the process to fully stop.
// Poll for StateStopped directly to avoid racing the StateStopping
// intermediate state that sits between StateReady and StateStopped.
deadline := time.Now().Add(5 * time.Second)
for p.State() != StateStopped && time.Now().Before(deadline) {
time.Sleep(testPollInterval)
}
if got := p.State(); got != StateStopped {
t.Fatalf("TTL did not stop process; state is %s (expected %s)", got, StateStopped)
}
// Run() should have returned nil (clean stop from TTL).
select {
case err := <-runErr:
if err != nil {
t.Errorf("Run() after TTL stop: expected nil, got %v", err)
}
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after TTL-induced stop")
}
}
// TestProcessCommand_TTL_ResetsOnRequest verifies that inflight requests
// prevent the TTL goroutine from stopping the process, and that the TTL timer
// resets after each request completes.
func TestProcessCommand_TTL_ResetsOnRequest(t *testing.T) {
skipIfNoSimpleResponder(t)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(mock.Close)
cmd, _ := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
UnloadAfter: 1, // 1-second TTL
})
runErr := runAsync(t, p)
defer func() {
if p.State() == StateReady {
p.Stop(testStopTimeout)
}
}()
// Keep sending requests for 1.5s — past the 1s TTL — and verify
// the process never stops while traffic is flowing.
stopAt := time.Now().Add(1500 * time.Millisecond)
for time.Now().Before(stopAt) {
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200, got %d", rr.Code)
}
if p.State() != StateReady {
t.Fatalf("process was stopped during active traffic (state=%s)", p.State())
}
time.Sleep(10 * time.Millisecond)
}
if got := p.State(); got != StateReady {
t.Fatalf("expected StateReady while traffic was active, got %s", got)
}
// Now stop manually to clean up.
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
}
// TestProcessCommand_TTL_ZeroDisables verifies that UnloadAfter=0 does not
// spawn a TTL goroutine — the process stays ready until explicitly stopped.
func TestProcessCommand_TTL_ZeroDisables(t *testing.T) {
skipIfNoSimpleResponder(t)
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(mock.Close)
cmd, _ := simpleResponderCmd(t, "-silent")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: mock.URL,
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
UnloadAfter: 0, // disabled
})
runErr := runAsync(t, p)
defer func() {
if p.State() == StateReady {
p.Stop(testStopTimeout)
}
}()
if got := p.State(); got != StateReady {
t.Fatalf("expected StateReady, got %s", got)
}
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
p.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200 after request, got %d", rr.Code)
}
// No TTL goroutine is spawned when UnloadAfter=0, so a brief sleep is
// enough to confirm the process remains ready.
time.Sleep(100 * time.Millisecond)
if got := p.State(); got != StateReady {
t.Fatalf("process was stopped unexpectedly (state=%s) with TTL=0", got)
}
// Cleanly stop.
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop() error: %v", err)
}
select {
case <-runErr:
case <-time.After(testReturnTimeout):
t.Fatal("Run() did not return after Stop")
}
}
// TestProcessCommand_ConcurrentRunStop launches many concurrent run/stop racing
// pairs to exercise the race detector and verify no deadlocks occur.
func TestProcessCommand_ConcurrentRunStop(t *testing.T) {
skipIfNoSimpleResponder(t)
for range 10 {
cmd, port := simpleResponderCmd(t, "-silent")
cfg := config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
}
if runtime.GOOS == "windows" {
cfg.CmdStop = "taskkill /f /t /pid ${PID}"
}
p := newProcessCommand(t, cfg)
runDone := make(chan struct{})
go func() {
defer close(runDone)
p.Run(testStartTimeout) //nolint: errcheck — one goroutine wins the race
}()
go func() {
p.Stop(testStopTimeout) //nolint: errcheck
}()
// Backstop: the racing Stop may have arrived before Run got on the
// channel (making it a no-op), so keep stopping until Run unblocks.
deadline := time.After(testStartTimeout)
for done := false; !done; {
select {
case <-runDone:
done = true
case <-deadline:
t.Fatal("Run did not return")
case <-time.After(testPollInterval):
p.Stop(testStopTimeout) //nolint: errcheck
}
}
}
}
+82
View File
@@ -0,0 +1,82 @@
package process
import (
"fmt"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/shared"
)
func TestProcessCommand_EmitsStateChangeEvents(t *testing.T) {
skipIfNoSimpleResponder(t)
var mu sync.Mutex
var transitions []shared.ProcessStateChangeEvent
cancel := event.On(func(e shared.ProcessStateChangeEvent) {
if e.ProcessName != t.Name() {
return
}
mu.Lock()
transitions = append(transitions, e)
mu.Unlock()
})
defer cancel()
cmd, port := simpleResponderCmd(t, "-silent", "-respond hello")
p := newProcessCommand(t, config.ModelConfig{
Cmd: cmd,
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
HealthCheckTimeout: 10,
})
runErr := runAsync(t, p)
if err := p.Stop(testStopTimeout); err != nil {
t.Fatalf("Stop: %v", err)
}
<-runErr
// Events are delivered asynchronously; give the dispatcher a moment.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
mu.Lock()
n := len(transitions)
mu.Unlock()
if n >= 4 {
break
}
time.Sleep(testPollInterval)
}
mu.Lock()
defer mu.Unlock()
for _, e := range transitions {
if e.OldState == e.NewState {
t.Errorf("emitted no-op transition: %s -> %s", e.OldState, e.NewState)
}
}
want := []string{
string(StateStopped) + "->" + string(StateStarting),
string(StateStarting) + "->" + string(StateReady),
string(StateReady) + "->" + string(StateStopping),
string(StateStopping) + "->" + string(StateStopped),
}
got := make([]string, len(transitions))
for i, e := range transitions {
got[i] = e.OldState + "->" + e.NewState
}
if len(got) != len(want) {
t.Fatalf("transitions = %v, want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("transitions = %v, want %v", got, want)
}
}
}
+44
View File
@@ -0,0 +1,44 @@
//go:build !windows
package process
import (
"os/exec"
"syscall"
)
// setProcAttributes starts the upstream in its own process group (Setpgid) so
// the entire process tree can be signalled at once via its negative PID. This
// is what lets us reap a forked grandchild — e.g. a shell wrapper that
// backgrounds the real binary and exits — instead of leaking it as an orphan
// that holds the inherited stdout/stderr pipes open.
func setProcAttributes(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}
// terminateProcessTree sends SIGTERM to the whole process group led by the
// command, giving every process in the tree a chance to shut down gracefully.
func terminateProcessTree(cmd *exec.Cmd) error {
return signalProcessTree(cmd, syscall.SIGTERM)
}
// killProcessTree sends SIGKILL to the whole process group, force-terminating
// every process in the tree.
func killProcessTree(cmd *exec.Cmd) error {
return signalProcessTree(cmd, syscall.SIGKILL)
}
// signalProcessTree signals the process group led by cmd.Process. Because the
// child was started with Setpgid it is its own group leader (pgid == pid), so
// targeting -pid reaches the child and every descendant still in the group.
// Falls back to signalling just the child if the group send fails (e.g. the
// group has already drained), so we never silently skip the signal.
func signalProcessTree(cmd *exec.Cmd, sig syscall.Signal) error {
if cmd == nil || cmd.Process == nil {
return nil
}
if err := syscall.Kill(-cmd.Process.Pid, sig); err != nil {
return cmd.Process.Signal(sig)
}
return nil
}
+53
View File
@@ -0,0 +1,53 @@
//go:build windows
package process
import (
"fmt"
"os/exec"
"syscall"
)
// setProcAttributes sets platform-specific process attributes. CREATE_NO_WINDOW
// keeps the upstream from spawning its own console window.
func setProcAttributes(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
HideWindow: true,
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
}
}
// terminateProcessTree requests a graceful shutdown of the whole process tree
// rooted at cmd.Process. Windows has no SIGTERM or process-group signalling, so
// we shell out to `taskkill /t`, which walks the child tree by PID — the
// equivalent of signalling a Unix process group. Without /f, taskkill asks the
// processes to close rather than force-killing them.
func terminateProcessTree(cmd *exec.Cmd) error {
return taskkillProcessTree(cmd, false)
}
// killProcessTree force-terminates the whole process tree rooted at cmd.Process
// via `taskkill /f /t`, so any descendant that ignored or outlived the graceful
// request is killed alongside the parent rather than leaked as an orphan.
func killProcessTree(cmd *exec.Cmd) error {
return taskkillProcessTree(cmd, true)
}
// taskkillProcessTree runs taskkill against cmd.Process.Pid. The /t flag
// terminates the process together with any child processes it started, which is
// the Windows analogue of signalling a Unix process group via its negative PID.
// When force is true the /f flag force-kills; otherwise taskkill requests a
// graceful close.
func taskkillProcessTree(cmd *exec.Cmd, force bool) error {
if cmd == nil || cmd.Process == nil {
return nil
}
args := make([]string, 0, 4)
if force {
args = append(args, "/f")
}
args = append(args, "/t", "/pid", fmt.Sprintf("%d", cmd.Process.Pid))
kill := exec.Command("taskkill", args...)
setProcAttributes(kill)
return kill.Run()
}
+7
View File
@@ -0,0 +1,7 @@
//go:build !windows
package process
// SetupTreeCleanup is a no-op on non-Windows platforms, where upstream process
// teardown is handled via process-group signalling (see runtime_unix.go).
func SetupTreeCleanup() error { return nil }
+50
View File
@@ -0,0 +1,50 @@
//go:build windows
package process
import (
"fmt"
"unsafe"
"golang.org/x/sys/windows"
)
// SetupTreeCleanup assigns the current process to a Windows Job Object
// configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. Upstream processes
// spawned afterwards are associated with the same job, so when llama-swap exits
// for any reason — graceful shutdown, a forced second Ctrl+C, or a crash — the
// OS terminates the whole job and reaps every child instead of leaving orphans
// behind. It is the parent-side complement to the per-process teardown in
// runtime_windows.go.
//
// The job handle is intentionally leaked for the lifetime of the process: the
// kill-on-close behaviour fires when the last handle is released, which the OS
// does when the process exits.
func SetupTreeCleanup() error {
job, err := windows.CreateJobObject(nil, nil)
if err != nil {
return fmt.Errorf("CreateJobObject: %w", err)
}
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{
BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{
LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
},
}
if _, err := windows.SetInformationJobObject(
job,
windows.JobObjectExtendedLimitInformation,
uintptr(unsafe.Pointer(&info)),
uint32(unsafe.Sizeof(info)),
); err != nil {
windows.CloseHandle(job)
return fmt.Errorf("SetInformationJobObject: %w", err)
}
if err := windows.AssignProcessToJobObject(job, windows.CurrentProcess()); err != nil {
windows.CloseHandle(job)
return fmt.Errorf("AssignProcessToJobObject: %w", err)
}
return nil
}
+39
View File
@@ -0,0 +1,39 @@
package ring
type Buffer[T any] struct {
buf []T
head int
size int
}
func NewBuffer[T any](capacity int) Buffer[T] {
if capacity < 1 {
capacity = 1
}
return Buffer[T]{buf: make([]T, capacity)}
}
// Push adds v, overwriting the oldest entry when the buffer is full.
func (r *Buffer[T]) Push(v T) {
cap := len(r.buf)
if r.size < cap {
r.buf[(r.head+r.size)%cap] = v
r.size++
} else {
r.buf[r.head] = v
r.head = (r.head + 1) % cap
}
}
// Slice returns all entries in insertion order as a new slice.
func (r *Buffer[T]) Slice() []T {
if r.size == 0 {
return nil
}
cap := len(r.buf)
result := make([]T, r.size)
for i := 0; i < r.size; i++ {
result[i] = r.buf[(r.head+i)%cap]
}
return result
}
+44
View File
@@ -0,0 +1,44 @@
package ring
import "testing"
const benchCap = 600 // matches default MaxAge/Every (1min / 100ms)
func BenchmarkBuffer_PushNoWrap(b *testing.B) {
for b.Loop() {
buf := NewBuffer[int](b.N + 1)
for i := range b.N {
buf.Push(i)
}
}
}
func BenchmarkBuffer_PushWrap(b *testing.B) {
buf := NewBuffer[int](benchCap)
b.ResetTimer()
for i := range b.N {
buf.Push(i)
}
}
func BenchmarkBuffer_Slice(b *testing.B) {
buf := NewBuffer[int](benchCap)
for i := range benchCap {
buf.Push(i)
}
b.ResetTimer()
for range b.N {
_ = buf.Slice()
}
}
func BenchmarkBuffer_PushAndSlice(b *testing.B) {
buf := NewBuffer[int](benchCap)
b.ResetTimer()
for i := range b.N {
buf.Push(i)
if i%benchCap == 0 {
_ = buf.Slice()
}
}
}
+65
View File
@@ -0,0 +1,65 @@
package ring
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBuffer_EmptySliceIsNil(t *testing.T) {
b := NewBuffer[int](4)
assert.Nil(t, b.Slice())
}
func TestBuffer_PushBelowCapacity(t *testing.T) {
b := NewBuffer[int](4)
b.Push(1)
b.Push(2)
assert.Equal(t, []int{1, 2}, b.Slice())
}
func TestBuffer_PushAtCapacity(t *testing.T) {
b := NewBuffer[int](3)
b.Push(1)
b.Push(2)
b.Push(3)
assert.Equal(t, []int{1, 2, 3}, b.Slice())
}
func TestBuffer_PushOverCapacityEvictsOldest(t *testing.T) {
b := NewBuffer[int](3)
b.Push(1)
b.Push(2)
b.Push(3)
b.Push(4)
assert.Equal(t, []int{2, 3, 4}, b.Slice())
}
func TestBuffer_CapacityOne(t *testing.T) {
b := NewBuffer[int](1)
b.Push(1)
b.Push(2)
assert.Equal(t, []int{2}, b.Slice())
}
func TestBuffer_ZeroCapacityDefaultsToOne(t *testing.T) {
b := NewBuffer[int](0)
b.Push(42)
assert.Equal(t, []int{42}, b.Slice())
}
func TestBuffer_SliceReturnsCopy(t *testing.T) {
b := NewBuffer[int](4)
b.Push(10)
s := b.Slice()
s[0] = 99
assert.Equal(t, []int{10}, b.Slice())
}
func TestBuffer_InsertionOrderPreservedAfterWrap(t *testing.T) {
b := NewBuffer[int](4)
for i := 1; i <= 8; i++ {
b.Push(i)
}
assert.Equal(t, []int{5, 6, 7, 8}, b.Slice())
}
+505
View File
@@ -0,0 +1,505 @@
package router
import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
"github.com/mostlygeek/llama-swap/internal/shared"
)
type shutdownReq struct {
timeout time.Duration
respond chan error
}
type unloadReq struct {
targets []string
timeout time.Duration
respond chan struct{}
}
// baseRouter owns the channels, run-loop, and process machinery shared by every
// concrete router. Concrete routers embed *baseRouter and supply a
// scheduler.Swapper describing how eviction sets are decided. baseRouter
// implements scheduler.Effects so the scheduler can call back for side-effects.
type baseRouter struct {
name string
config config.Config
processes map[string]process.Process
logger *logmon.Monitor
schedule scheduler.Scheduler
// shutdownCtx governs the request machinery: cancelling it tells grant()
// and ServeHTTP to stop granting and reject callers. It is deliberately
// separate from procCtx — see procCtx below.
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
// procCtx is the parent context for every managed process and governs
// process lifetime only. handleShutdown stops processes gracefully via
// Stop() and cancels procCtx afterwards, so teardown is never a context
// cancel racing the graceful path (which collapsed the grace to 100ms and
// let the caller return before children were reaped — see process run loop).
procCtx context.Context
procCancel context.CancelFunc
handlerCh chan scheduler.HandlerReq
cancelCh chan scheduler.HandlerReq
shutdownCh chan shutdownReq
unloadCh chan unloadReq
swapDoneCh chan scheduler.SwapDone
serveDoneCh chan scheduler.ServeDoneEvent
runDone chan struct{}
// testProcessed, when non-nil, receives one event after each handlerReq
// or swapDone has been fully processed by run(). Tests use it to wait
// for run() to reach a deterministic state without sleeping. serveDone
// events are intentionally NOT signalled here so test event counts
// remain stable.
testProcessed chan struct{}
}
func newBaseRouter(
name string,
conf config.Config,
processes map[string]process.Process,
logger *logmon.Monitor,
planner scheduler.Swapper,
) (*baseRouter, error) {
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
procCtx, procCancel := context.WithCancel(context.Background())
b := &baseRouter{
name: name,
config: conf,
processes: processes,
logger: logger,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
procCtx: procCtx,
procCancel: procCancel,
handlerCh: make(chan scheduler.HandlerReq),
cancelCh: make(chan scheduler.HandlerReq),
shutdownCh: make(chan shutdownReq),
unloadCh: make(chan unloadReq),
swapDoneCh: make(chan scheduler.SwapDone),
serveDoneCh: make(chan scheduler.ServeDoneEvent),
runDone: make(chan struct{}),
}
sched, err := scheduler.New(conf, name, logger, planner, b)
if err != nil {
return nil, err
}
b.schedule = sched
return b, nil
}
func (b *baseRouter) notifyProcessed() {
if b.testProcessed != nil {
b.testProcessed <- struct{}{}
}
}
func (b *baseRouter) run() {
defer close(b.runDone)
for {
select {
case req := <-b.shutdownCh:
b.handleShutdown(req)
return
case req := <-b.handlerCh:
b.schedule.OnRequest(req)
b.notifyProcessed()
case req := <-b.cancelCh:
b.schedule.OnCancel(req)
b.notifyProcessed()
case req := <-b.unloadCh:
b.schedule.OnUnload(req.targets, req.timeout)
close(req.respond)
b.notifyProcessed()
case ev := <-b.swapDoneCh:
b.schedule.OnSwapDone(ev)
b.notifyProcessed()
case ev := <-b.serveDoneCh:
b.schedule.OnServeDone(ev)
}
}
}
// grant sends a response back to the caller of ServeHTTP and tells us
// whether the caller was still there to receive it.
//
// Each ServeHTTP creates a fresh, UNBUFFERED respond channel and parks in
// a select waiting on it. "Unbuffered" is the important word: a send only
// completes when the other side is actively receiving. So if this send
// succeeds, we know for a fact the caller picked up the response and will
// act on it. If the caller has already given up (its request context was
// cancelled, e.g. the HTTP client disconnected) or the router is shutting
// down, the send never lands, one of the other select cases fires, and we
// report back that the grant did NOT happen.
//
// That distinction matters for in-flight bookkeeping — see GrantServe.
func (b *baseRouter) grant(req scheduler.HandlerReq, resp scheduler.HandlerResp) bool {
select {
case req.Respond <- resp:
return true
case <-req.Ctx.Done():
return false
case <-b.shutdownCtx.Done():
return false
}
}
// ModelState implements scheduler.Effects.
func (b *baseRouter) ModelState(modelID string) (process.ProcessState, bool) {
p, ok := b.processes[modelID]
if !ok {
var zero process.ProcessState
return zero, false
}
return p.State(), true
}
// StartSwap implements scheduler.Effects, launching the swap goroutine.
func (b *baseRouter) StartSwap(modelID string, evict []string) {
go b.doSwap(modelID, evict)
}
// GrantError implements scheduler.Effects.
func (b *baseRouter) GrantError(req scheduler.HandlerReq, err error) {
b.grant(req, scheduler.HandlerResp{Err: err})
}
// GrantServe implements scheduler.Effects. It hands the caller a wrapped
// p.ServeHTTP (via trackedServe) so the run loop hears about the request
// finishing, and reports whether the caller received it. The scheduler bumps
// its in-flight count only on a true return: if grant() returns false the
// caller already walked away and trackedServe will never run, so no matching
// decrement will ever arrive — incrementing would strand the counter at >0 and
// the router would never again be willing to evict this model.
func (b *baseRouter) GrantServe(req scheduler.HandlerReq, modelID string) bool {
p := b.processes[modelID]
return b.grant(req, scheduler.HandlerResp{HandleFunc: b.trackedServe(modelID, p)})
}
// StopProcesses implements scheduler.Effects, stopping the named processes in
// parallel and blocking until all have stopped.
func (b *baseRouter) StopProcesses(timeout time.Duration, ids []string) {
var wg sync.WaitGroup
for _, id := range ids {
p, ok := b.processes[id]
if !ok {
continue
}
wg.Add(1)
go func(id string, p process.Process) {
defer wg.Done()
if err := p.Stop(timeout); err != nil {
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
}
}(id, p)
}
wg.Wait()
}
// trackedServe is the wrapper that closes the loop on in-flight tracking.
// It runs p.ServeHTTP normally; the only added behaviour is a deferred
// send on serveDoneCh after the handler returns. That send is what tells
// the run loop "this model now has one fewer request in flight — go look
// at the queue again, you may be able to start a swap you previously had
// to defer."
//
// The select on shutdownCtx.Done() is a release valve: if the router is
// already shutting down, nobody is reading serveDoneCh, so we drop the
// notification rather than blocking the HTTP goroutine forever.
func (b *baseRouter) trackedServe(modelID string, p process.Process) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
select {
case b.serveDoneCh <- scheduler.ServeDoneEvent{ModelID: modelID}:
case <-b.shutdownCtx.Done():
}
}()
p.ServeHTTP(w, r)
}
}
func (b *baseRouter) doSwap(modelID string, toStop []string) {
timeout := b.healthCheckTimeout()
var wg sync.WaitGroup
for _, mID := range toStop {
wg.Add(1)
go func(p process.Process, id string) {
defer wg.Done()
if err := p.Stop(timeout); err != nil {
b.logger.Warnf("%s: stopping %s failed: %v", b.name, id, err)
}
}(b.processes[mID], mID)
}
wg.Wait()
target := b.processes[modelID]
if target.State() == process.StateStopped {
go func() {
if err := target.Run(timeout); err != nil {
b.logger.Warnf("%s: running %s exited: %v", b.name, modelID, err)
}
}()
}
err := target.WaitReady(b.shutdownCtx)
select {
case b.swapDoneCh <- scheduler.SwapDone{ModelID: modelID, Err: err}:
case <-b.shutdownCtx.Done():
}
}
func (b *baseRouter) handleShutdown(req shutdownReq) {
shutdownErr := fmt.Errorf("%s is shutting down", b.name)
// Cancel shutdownCtx first so any waiter that is currently parked on
// its respond channel can exit via its own shutdownCtx.Done() branch.
// The OnShutdown grants below then either land (waiter happened to receive
// before noticing shutdown) or fall through immediately via grant's
// shutdownCtx case — either way the waiter sees a non-OK response.
// This does NOT touch processes: their lifetime is procCtx, cancelled
// only after the graceful Stop() calls below have reaped them.
b.shutdownFn()
b.schedule.OnShutdown(shutdownErr)
stopTimeout := req.timeout
if stopTimeout <= 0 {
stopTimeout = b.healthCheckTimeout()
}
var wg sync.WaitGroup
for i, p := range b.processes {
wg.Add(1)
go func(id string, p process.Process) {
defer wg.Done()
if err := p.Stop(stopTimeout); err != nil {
b.logger.Warnf("%s failed to stop process %s: %v", b.name, id, err)
}
}(i, p)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
if req.timeout > 0 {
select {
case <-done:
case <-time.After(req.timeout):
<-done
}
} else {
<-done
}
// Every process is stopped (children reaped via Stop()). Cancel procCtx so
// the process run-loop goroutines exit; they are already StateStopped, so
// this is a clean no-op kill rather than a forced teardown.
b.procCancel()
req.respond <- nil
}
func (b *baseRouter) healthCheckTimeout() time.Duration {
t := time.Duration(b.config.HealthCheckTimeout) * time.Second
if t <= 0 {
return 30 * time.Second
}
return t
}
func (b *baseRouter) Handles(model string) bool {
_, ok := b.processes[model]
return ok
}
func (b *baseRouter) ProcessLogger(modelID string) (*logmon.Monitor, bool) {
if p, ok := b.processes[modelID]; ok {
return p.Logger(), true
}
return nil, false
}
// RunningModels returns the current state of every process that is not stopped
// or shut down. The processes map keys are fixed at construction and State()
// is a snapshot, so this is safe to call without the run loop.
func (b *baseRouter) RunningModels() map[string]process.ProcessState {
running := make(map[string]process.ProcessState)
for id, p := range b.processes {
st := p.State()
if st == process.StateStopped || st == process.StateShutdown {
continue
}
running[id] = st
}
return running
}
// Unload stops the named models, or every running model when none are named.
// It blocks until each targeted process has stopped.
//
// The request is funneled through the run loop so eviction is coordinated
// with the rest of the router's state: pending swap waiters for an
// unloaded model are released with an error, queued requests for unloaded
// models are dropped, and any deferred swaps that were waiting on those
// models become eligible to start.
//
// In-flight requests being served by an unloaded process are not waited
// for — Stop kills the upstream, those callers see whatever error the
// reverse proxy surfaces and may retry. Their trackedServe defers fire
// normally and decrement inFlight as the dying handlers return.
func (b *baseRouter) Unload(timeout time.Duration, models ...string) {
targets := models
if len(targets) == 0 {
targets = make([]string, 0, len(b.processes))
for id := range b.processes {
targets = append(targets, id)
}
}
if len(targets) == 0 {
return
}
req := unloadReq{targets: targets, timeout: timeout, respond: make(chan struct{})}
select {
case b.unloadCh <- req:
case <-b.runDone:
return
}
<-req.respond
}
func (b *baseRouter) Shutdown(timeout time.Duration) error {
if !b.shuttingDown.CompareAndSwap(false, true) {
return fmt.Errorf("%s shutdown already in progress", b.name)
}
req := shutdownReq{timeout: timeout, respond: make(chan error, 1)}
select {
case b.shutdownCh <- req:
case <-b.runDone:
return nil
}
return <-req.respond
}
func (b *baseRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if b.shuttingDown.Load() {
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
data, err := shared.FetchContext(req, b.config)
if err != nil {
shared.SendError(w, req, err)
return
}
hr := scheduler.HandlerReq{
Model: data.ModelID,
Ctx: req.Context(),
// Unbuffered: a successful send on Respond proves the waiter is
// alive and consuming. grant() relies on this to avoid handing a
// handleFunc to a cancelled waiter and leaking the inFlight count.
Respond: make(chan scheduler.HandlerResp),
PositionCh: make(chan int, 1),
}
select {
case b.handlerCh <- hr:
case <-req.Context().Done():
return
case <-b.shutdownCtx.Done():
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
isModelReady := false
if p, ok := b.processes[data.ModelID]; ok {
isModelReady = p.State() == process.StateReady
}
shouldShowLoading := data.Streaming && data.SendLoadingState && isLoadingPath(req.URL.Path) && !isModelReady
var lw *loadingWriter
cancelLoad := func() {}
if shouldShowLoading {
var swapCtx context.Context
swapCtx, cancelLoad = context.WithCancel(req.Context())
lw = newLoadingWriter(b.logger, data.ModelID, w, req)
go lw.start(swapCtx)
go func() {
for {
select {
case pos := <-hr.PositionCh:
lw.setUpdate(fmt.Sprintf("Queue position: #%d", pos))
case <-swapCtx.Done():
return
}
}
}()
}
// finishLoading stops the loading stream and fences its goroutine off from
// the ResponseWriter before the real handler (or ServeHTTP's return)
// reclaims it. release() must run even when waitForCompletion times out:
// otherwise a still-streaming goroutine flushes a finalized response and
// panics on the recycled *bufio.Writer.
finishLoading := func() {
cancelLoad()
if lw != nil {
lw.waitForCompletion(1 * time.Second)
lw.release()
}
}
var resp scheduler.HandlerResp
select {
case resp = <-hr.Respond:
finishLoading()
case <-req.Context().Done():
finishLoading()
// Notify the scheduler so it can prune this request from its queue
// and swap waiters. Without this, a queued request whose client left
// would sit in the scheduler until drainQueue eventually starts a
// wasted model load for it.
select {
case b.cancelCh <- hr:
case <-b.shutdownCtx.Done():
}
return
case <-b.shutdownCtx.Done():
finishLoading()
shared.SendError(w, req, fmt.Errorf("%s is shutting down", b.name))
return
}
if resp.Err != nil {
shared.SendError(w, req, resp.Err)
return
}
resp.HandleFunc(w, req)
}
+264
View File
@@ -0,0 +1,264 @@
package router
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/router/scheduler"
)
// These tests cover baseRouter's own machinery — the run loop, process
// lifecycle (doSwap), grant/ServeHTTP plumbing, Unload, and Shutdown. The
// scheduling decision logic (queueing, collation, eviction collisions) lives in
// the scheduler package and is tested directly there; see fifo_test.go.
// stubPlanner evicts nothing. baseRouter tests drive the run loop through the
// default FIFO scheduler without exercising any particular eviction policy.
type stubPlanner struct{}
func (s *stubPlanner) EvictionFor(string, []string) []string { return nil }
func (s *stubPlanner) OnSwapStart(string, []string) {}
func newTestBase(t *testing.T, processes map[string]process.Process, planner scheduler.Swapper) *baseRouter {
t.Helper()
conf := config.Config{HealthCheckTimeout: 5}
b, err := newBaseRouter("test", conf, processes, logmon.NewWriter(io.Discard), planner)
if err != nil {
t.Fatalf("newBaseRouter: %v", err)
}
b.testProcessed = make(chan struct{}, 64)
go b.run()
t.Cleanup(func() {
if !b.shuttingDown.Load() {
_ = b.Shutdown(time.Second)
}
})
return b
}
func TestBaseRouter_RunningModels(t *testing.T) {
ready := newFakeProcess("ready")
ready.markReady()
starting := newFakeProcess("starting")
starting.setState(process.StateStarting)
stopped := newFakeProcess("stopped")
b := newTestBase(t, map[string]process.Process{
"ready": ready, "starting": starting, "stopped": stopped,
}, &stubPlanner{})
running := b.RunningModels()
if len(running) != 2 {
t.Fatalf("running=%v want 2 entries", running)
}
if running["ready"] != process.StateReady {
t.Errorf("ready state=%q want ready", running["ready"])
}
if running["starting"] != process.StateStarting {
t.Errorf("starting state=%q want starting", running["starting"])
}
if _, ok := running["stopped"]; ok {
t.Errorf("stopped process should be excluded from RunningModels")
}
}
func TestBaseRouter_UnloadAll(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
c := newFakeProcess("c")
c.markReady()
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
b.Unload(time.Second)
if a.State() != process.StateStopped || c.State() != process.StateStopped {
t.Fatalf("Unload() should stop every process: a=%q c=%q", a.State(), c.State())
}
}
func TestBaseRouter_UnloadSpecificModel(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
c := newFakeProcess("c")
c.markReady()
b := newTestBase(t, map[string]process.Process{"a": a, "c": c}, &stubPlanner{})
b.Unload(time.Second, "a")
if a.State() != process.StateStopped {
t.Errorf("a should be stopped, got %q", a.State())
}
if c.State() != process.StateReady {
t.Errorf("c should remain ready, got %q", c.State())
}
}
// TestBaseRouter_Unload_StopsInParallel verifies that Unload fans out its
// Stop calls concurrently rather than stopping each process serially. Each
// fakeProcess.Stop is pinned via stopBlock; the test only releases them
// after observing every stopStarted, proving all three Stops were in
// flight simultaneously.
func TestBaseRouter_Unload_StopsInParallel(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
a.stopBlock = make(chan struct{})
pb := newFakeProcess("b")
pb.markReady()
pb.stopBlock = make(chan struct{})
pc := newFakeProcess("c")
pc.markReady()
pc.stopBlock = make(chan struct{})
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb, "c": pc}, &stubPlanner{})
unloadDone := make(chan struct{})
go func() {
b.Unload(time.Second, "a", "b", "c")
close(unloadDone)
}()
// All three Stop calls must start before any of them are allowed to
// complete. If Unload was serial, only one stopStarted would fire
// until we released its stopBlock, and this would deadlock.
for _, p := range []*fakeProcess{a, pb, pc} {
select {
case <-p.stopStarted:
case <-time.After(2 * time.Second):
t.Fatalf("Stop on %s never started — Unload is not parallel", p.id)
}
}
// Release them; Unload should now return.
close(a.stopBlock)
close(pb.stopBlock)
close(pc.stopBlock)
select {
case <-unloadDone:
case <-time.After(2 * time.Second):
t.Fatal("Unload did not return after stops released")
}
for _, p := range []*fakeProcess{a, pb, pc} {
if p.State() != process.StateStopped {
t.Errorf("%s state=%q want stopped", p.id, p.State())
}
if got := p.stopCalls.Load(); got != 1 {
t.Errorf("%s stopCalls=%d want 1", p.id, got)
}
}
}
func TestBaseRouter_OnDemandStart(t *testing.T) {
a := newFakeProcess("a")
a.autoReady = true
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.runCalls.Load(); got != 1 {
t.Errorf("runCalls=%d want 1", got)
}
if got := a.serveCalls.Load(); got != 1 {
t.Errorf("serveCalls=%d want 1", got)
}
}
func TestBaseRouter_ContextCancel(t *testing.T) {
a := newFakeProcess("a")
// autoReady=false so swap parks forever until we mark ready.
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
ctx, cancel := context.WithCancel(context.Background())
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
b.ServeHTTP(w1, newRequestCtx(ctx, "a"))
close(done1)
}()
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
b.ServeHTTP(w2, newRequest("a"))
close(done2)
}()
waitProcessed(t, b.testProcessed, 2) // both requests joined the active swap
<-a.runStarted
cancel()
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("cancelled ServeHTTP did not return after ctx cancel")
}
a.markReady()
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("non-cancelled ServeHTTP did not complete after swap")
}
if w2.Code != http.StatusOK {
t.Errorf("second request status=%d body=%q", w2.Code, w2.Body.String())
}
}
func TestBaseRouter_ModelNotFound(t *testing.T) {
a := newFakeProcess("a")
b := newTestBase(t, map[string]process.Process{"a": a}, &stubPlanner{})
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("unknown"))
if w.Code != http.StatusNotFound {
t.Errorf("status=%d want %d body=%q", w.Code, http.StatusNotFound, w.Body.String())
}
}
func TestBaseRouter_Shutdown_StopsAllProcesses(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
pb := newFakeProcess("b")
pb.markReady()
go pb.Run(0)
b := newTestBase(t, map[string]process.Process{"a": a, "b": pb}, &stubPlanner{})
if err := b.Shutdown(time.Second); err != nil {
t.Fatalf("Shutdown: %v", err)
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := pb.stopCalls.Load(); got != 1 {
t.Errorf("b.stopCalls=%d want 1", got)
}
// Subsequent ServeHTTP should report 5xx.
w := httptest.NewRecorder()
b.ServeHTTP(w, newRequest("a"))
if w.Code != http.StatusInternalServerError && w.Code != http.StatusServiceUnavailable {
t.Errorf("post-shutdown status=%d want 5xx body=%q", w.Code, w.Body.String())
}
// Second Shutdown should report already in progress.
if err := b.Shutdown(0); err == nil {
t.Errorf("second Shutdown returned nil, want error")
}
}
+404
View File
@@ -0,0 +1,404 @@
# Router design
A developer tutorial for the `internal/router` package and its `scheduler`
sub-package.
## Intro
A llama-swap router is the component that sits behind the proxy and answers one
question for every incoming request: _can this model serve right now, and if
not, what has to happen first?_ Answering it means juggling three concerns that
used to live tangled together in one type:
1. **Process machinery** — owning the OS processes, starting and stopping them,
running health checks, and shuttling HTTP requests onto the right upstream.
2. **Scheduling strategy** — the queue, in-flight bookkeeping, and the decision
tree that turns one request into "serve now", "join an existing swap",
"queue", or "start a swap".
3. **Eviction policy** — given a model we want to load, which currently-running
models have to be stopped to make room?
The design pulls those three apart into separate, independently replaceable
pieces:
| Concern | Type | Lives in |
| ------------------- | ------------------------------ | ------------------------------- |
| Process machinery | `baseRouter` | `internal/router/base.go` |
| Scheduling strategy | `scheduler.Scheduler` (`FIFO`) | `internal/router/scheduler/` |
| Eviction policy | `scheduler.Swapper` | `groupSwapper`, `matrixSwapper` |
`baseRouter` keeps the channels, run loop, process lifecycle, and shutdown
teardown, and exposes the side-effects a scheduler needs through the
`scheduler.Effects` interface. The scheduler owns the queue and decision tree
but performs no side-effects directly — it calls back through `Effects`. The
`Swapper` is a pure function from "target model + currently running" to "models
to evict", and knows nothing about queues, channels, or processes.
Because the seams are interfaces, you can replace the scheduling strategy
without touching process management, or write a new eviction policy without
touching either. `FIFO` is the first and currently only `Scheduler`;
`groupSwapper` and `matrixSwapper` are the two `Swapper`s.
## Key concepts
### One run loop, no locks
`baseRouter.run()` is a single goroutine selecting over a handful of channels:
```go
for {
select {
case req := <-b.shutdownCh: b.handleShutdown(req); return
case req := <-b.handlerCh: b.schedule.OnRequest(req)
case req := <-b.unloadCh: b.schedule.OnUnload(req.targets, req.timeout); close(req.respond)
case ev := <-b.swapDoneCh: b.schedule.OnSwapDone(ev)
case ev := <-b.serveDoneCh: b.schedule.OnServeDone(ev)
}
}
```
Every `Scheduler` method runs on this one goroutine. That is the single most
important fact about the design: **the scheduler never needs a mutex for its own
state**. All scheduler state is touched only from these callbacks, which are
serialized by the run loop. If you write a new scheduler, you get the same
guarantee for free — and you must not break it by spinning up goroutines that
mutate scheduler state.
### Events flow in, side-effects flow out
The run loop turns external happenings into method calls on the scheduler:
- A new HTTP request becomes `OnRequest(HandlerReq)`.
- A swap goroutine finishing becomes `OnSwapDone(SwapDone)`.
- A tracked request handler returning becomes `OnServeDone(ServeDoneEvent)`.
- An admin unload becomes `OnUnload(targets, timeout)`.
- Shutdown becomes `OnShutdown(err)`.
The scheduler reacts by calling **back out** through `Effects`: inspect a
process state, start a swap, grant a response to a caller, or stop processes. It
never calls `process.Process` directly and never writes to a channel directly.
This keeps the scheduler pure enough to unit-test against a fake `Effects` with
no goroutines or real processes involved (see `scheduler/fifo_test.go`).
```
HTTP request admin Unload / Shutdown
│ │
▼ ▼
ServeHTTP ──HandlerReq──▶ baseRouter.run() ◀──unloadCh/shutdownCh
│ (single goroutine)
Scheduler.On*(...)
│ calls back through
Effects: ModelState / StartSwap /
GrantServe / GrantError / StopProcesses
baseRouter side-effects: doSwap goroutine,
grant() to caller, process.Stop()
swap completes ──SwapDone──▶ back into run loop
```
### The swap goroutine
Scheduling decisions must be quick and non-blocking, but loading a model is
slow. The two are reconciled by doing the slow part on a separate goroutine.
When the scheduler decides to start a swap, inside `OnRequest` it:
1. records "a swap for X is in flight" in its own state, then
2. calls `Effects.StartSwap(modelID, evict)`.
`StartSwap` does **not** load the model itself — it just launches a detached
goroutine (`doSwap`) and returns straight away. `doSwap` is what does the slow
work: stop the evicted processes, start the target, wait for it to become ready.
Because `StartSwap` returned immediately, `OnRequest` returns too, and the run
loop is free to pick up the next event — another request, a serve-done, an
unload — while `doSwap` runs in the background.
The swap's eventual result comes back as just another event: when `doSwap`
finishes it posts a `SwapDone` onto `swapDoneCh`, which the run loop delivers as
`OnSwapDone`. So a slow load never blocks the run loop; it brackets it with two
quick events (`OnRequest` to start, `OnSwapDone` to finish) and everything in
between is handled normally.
### In-flight tracking and `trackedServe`
When the scheduler grants a request, the handler it hands back is wrapped by
`baseRouter.trackedServe`. The wrapper runs the real `ServeHTTP` and, on return,
posts a `ServeDoneEvent` so the run loop can decrement the per-model in-flight
count. This is why the scheduler can know whether a process is "busy": it counts
grants out and serve-dones in. A swap that would evict a busy process is
deferred until that process's in-flight count hits zero (`OnServeDone` then
re-drains the queue).
The subtle contract here is `GrantServe`'s boolean return. The caller's
`Respond` channel is unbuffered, so a successful send proves the HTTP goroutine
is alive and took the handler. If the caller already disconnected, the send
fails, `trackedServe` never runs, and **no** `ServeDoneEvent` will ever arrive —
so the scheduler must only increment `inFlight` when `GrantServe` returns true.
Incrementing on a false return would strand the counter above zero and the model
could never be evicted again.
## The interfaces
All three live in `scheduler/scheduler.go`.
### `Scheduler`
```go
type Scheduler interface {
OnRequest(req HandlerReq)
OnSwapDone(ev SwapDone)
OnServeDone(ev ServeDoneEvent)
OnUnload(targets []string, timeout time.Duration)
OnShutdown(err error)
}
```
Owns the queue, in-flight tracking, and the decision tree. All methods run on
the run-loop goroutine, so no internal locking is needed.
### `Swapper`
```go
type Swapper interface {
EvictionFor(target string, running []string) []string
OnSwapStart(target string, running []string)
}
```
The eviction policy. `EvictionFor` is a **pure decision** — given the target and
the complete `running` set, return the running model IDs that must stop. It must
not log or mutate anything, and it does **not** inspect process state itself:
the scheduler hands it `running` already assembled (every non-stopped process,
unioned with the targets of in-flight swaps already committed but not yet
visible in process state). That keeps the swapper a pure function of its inputs,
with no reference to processes.
The reason it must not log is that it is a _speculative_ query — "what would we
evict if we started this swap right now?" — called far more often than swaps
actually happen. The scheduler calls it once per incoming request, and then
**again for every still-queued request on every queue drain** (each `OnSwapDone`,
`OnServeDone`, and `OnUnload`). Most of those calls end in "still queued",
"collides", or "nothing to evict", not a real swap. Logging there would emit
duplicate lines for a request that simply sits in the queue, and lines for
decisions that never happen — the log would stop meaning "a swap occurred".
`OnSwapStart` is the one place a Swapper may log, because it is called exactly
once, at the moment a swap is committed. One log line there equals one real swap,
with the evict set that is genuinely being applied — which is why `matrixSwapper`
re-solves and logs the full decision (set, DSL, cost) in `OnSwapStart` rather
than in `EvictionFor`.
### `Effects`
```go
type Effects interface {
ModelState(modelID string) (process.ProcessState, bool)
RunningModels() map[string]process.ProcessState
StartSwap(modelID string, evict []string)
GrantError(req HandlerReq, err error)
GrantServe(req HandlerReq, modelID string) bool
StopProcesses(timeout time.Duration, ids []string)
}
```
Implemented by `baseRouter`. This is the scheduler's entire window onto the
outside world; everything else about the router is hidden from it. See the
deep-dive below.
### `Factory` — wiring it together
```go
type Factory func(name string, logger *logmon.Monitor, eff Effects) Scheduler
```
`baseRouter` doesn't know which scheduler or swapper it has — it is handed a
`Factory` at construction and calls it once, passing itself as the `Effects`.
The concrete router captures its `Swapper` in the closure. From `group.go`:
```go
swapper := &groupSwapper{ /* ... */ }
base := newBaseRouter("group", conf, processes, proxylog,
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
return scheduler.NewFIFO(name, logger, swapper, eff)
})
```
This closure is the single point where the three pieces meet: it binds a
specific `Swapper` (`swapper`) and a specific `Scheduler` (`FIFO`) to the
`baseRouter`'s `Effects` (`eff`).
**The swapper is a separate type from the concrete router.** There are currently two router implementations router.Group and router.Matrix. Each of these has a custom swapper that implements scheduler.Swapper for custom eviction logic. This decoupling of responsibilities makes it easy to implement custom swapping strategies.
### The events
A single goroutine in `baseRouter.run()` owns and serializes all state changes in the router. By processing events one at a time it ensures correctness and eliminates complex mutex lock logic.
These are the events the router currently uses:
```go
type HandlerReq struct { // one in-flight ServeHTTP awaiting a decision
Model string
Ctx context.Context
Respond chan HandlerResp // UNBUFFERED — see GrantServe contract
PositionCh chan int // queue-position updates for the loading UI
}
type HandlerResp struct { // the decision handed back to the caller
HandleFunc http.HandlerFunc // serve with this, or...
Err error // ...fail with this
}
type SwapDone struct{ ModelID string; Err error } // swap goroutine finished
type ServeDoneEvent struct{ ModelID string } // tracked handler returned
```
## Deep-dive: the `Effects` interface and why it exists
`Effects` is the inversion-of-control boundary that makes the split possible.
The scheduler decides and `baseRouter` _acts_. Pulling the side-effects behind this
interface buys three things:
1. **Purity and testability.** The scheduler performs no I/O, starts no
goroutines of its own, and touches no real processes. Its tests drive the
`On*` methods directly and assert on a `fakeEffects` that just records the
calls — synchronous, deterministic, no sleeps. (`scheduler/fifo_test.go`.)
2. **A single, auditable side-effect surface.** Every externally-visible thing a
scheduler can do is one of six methods. You can reason about the whole
contract by reading one interface.
3. **Decoupling lifetime.** The scheduler never holds a `process.Process`,
never sees a channel, and never learns how shutdown teardown works. It only
knows model IDs and states.
Method by method, as implemented in `base.go`:
- **`ModelState(modelID) (state, ok)`** — read-only snapshot of a process's
state, and whether this router handles the model at all. The scheduler uses it
for the "unknown model" check and the "already ready" fast path. Safe to call
any time because the process map is fixed at construction and `State()` is a
snapshot.
- **`RunningModels()`** — the state of every process that isn't stopped or shut
down. The scheduler unions its keys with its own in-flight swap targets to
build the `running` set it hands the `Swapper`, so the swapper never has to
touch process state itself.
- **`StartSwap(modelID, evict)`** — fire-and-forget. `baseRouter` launches the
`doSwap` goroutine and returns immediately; the result comes back later as a
`SwapDone`. The scheduler records the swap as active _before_ calling this so
that requests arriving in the meantime can join it.
- **`GrantError(req, err)`** — hand a caller an error response. Used for unknown
models, failed swaps, unloads, and shutdown.
- **`GrantServe(req, modelID) bool`** — hand a caller the tracked handler for a
ready model, returning whether the caller was still there to receive it. The
scheduler increments the in-flight count **only on a true return** (see the
in-flight contract above). This is the one `Effects` method whose return value
carries state-machine significance.
- **`StopProcesses(timeout, ids)`** — stop processes in parallel and **block**
until all have stopped. Used by `OnUnload` so an admin `Unload` call can
guarantee the process is dead by the time it returns. (Note `StartSwap` is
async but `StopProcesses` is sync — the difference is deliberate and tied to
the caller's expectations.)
A useful way to hold it in your head: `Effects` is the scheduler's syscall
table. The scheduler is a pure state machine; `Effects` is how it touches the
world, and `baseRouter` is the kernel that implements those syscalls with real
goroutines, channels, and processes.
## How to implement a new `Swapper`
A `Swapper` is a pure decision function plus a logging hook — the easiest of the three pieces to replace.
1. **Write the swapper type** and give it whatever config it needs to make a
decision. It does **not** need the process map — the scheduler supplies the
running set as an argument. `groupSwapper` holds only its group config;
`matrixSwapper` holds only its solver and logger:
```go
type mySwapper struct {
config config.Config
}
```
2. **Implement `EvictionFor(target, running)`** as a _pure_ decision:
- `running` is the complete live set, already assembled for you: every
non-stopped process unioned with the targets of in-flight swaps the
scheduler has committed to. You don't filter process state or fold in
in-flight targets yourself, that's the scheduler's job. Just decide against the slice you're handed.
- Return the list of model IDs in `running` that must stop for `target` to
run. Return `nil`/empty when nothing needs evicting.
- Do **not** mutate state here.
- Do **not** log here. It can be called multiple times per request. Since it is pure function have tests verify the expected behaviour.
3. **Implement `OnSwapStart(target, running)`** — called once when a swap
actually begins, with the same `running` set `EvictionFor` saw. This is the
right place to log: one call equals one real swap. `matrixSwapper` re-solves
and logs the chosen set and cost here; `groupSwapper` logs nothing.
4. **Wire it in** by instantiating the swapper in your router's constructor and
capturing it in the `Factory` closure passed to `newBaseRouter` — exactly as
`NewGroup` and `NewMatrix` do. The router struct itself only ever embeds
`*baseRouter`; the swapper reaches the scheduler solely through that closure.
Reference implementations: `groupSwapper` (static group config) in `group.go`
and `matrixSwapper` (cost-based set solver) in `matrix.go`.
## How to implement a new `Scheduler`
Replacing the scheduler means taking over the queue and the entire decision tree. Read `scheduler/fifo.go` end to end first — it is the reference implementation and the rules below are easiest to understand in context.
The rules you must honour:
- **Single goroutine.** Every method runs on the `baseRouter.run()` goroutine. Keep your state in plain maps/slices and never read or write it from another goroutine. If you need slow work done, hand it to `Effects.StartSwap` and react to the resulting `SwapDone` — do not block a method waiting for it.
- **Never block the run loop.** `OnRequest`, `OnSwapDone`, and `OnServeDone` must make a decision and return. The one method allowed to block is `OnUnload`, and only because it must wait on the synchronous `StopProcesses` so the admin caller's guarantee holds.
- **Respect the `GrantServe` boolean.** Only count a request as in-flight when `GrantServe` returns true (see the in-flight contract above). A false return means the caller is gone; no `ServeDoneEvent` will ever arrive, so incrementing on false permanently strands the counter.
- **Account for in-flight swaps in your running set.** When you call `Swapper.EvictionFor`, the running set you pass must include not just live processes (`Effects.RunningModels`) but also the targets of swaps you've already started that aren't yet visible in process state — otherwise the swapper contradicts decisions already in motion.
What each method must do:
- **`OnRequest(req)`** — every request must resolve to exactly one of: granted, errored, joined (piggybacks an in-flight swap), queued, or swap-started. No request may be silently dropped.
- **`OnSwapDone(ev)`** — deliver the result to every waiter that joined this swap (grant on success, error on `ev.Err`), drop the swap from active tracking, then re-examine anything queued — a finished swap may have unblocked it.
- **`OnServeDone(ev)`** — decrement the model's in-flight count; when it hits zero, re-examine the queue. Do **not** clear in-flight counts by hand; the handlers post their own `ServeDoneEvent`s on return.
- **`OnUnload(targets, timeout)`** — error out any waiters or queued requests for the unloaded models, call `Effects.StopProcesses` (synchronously — the admin caller relies on the process being dead afterwards), then re-examine the queue.
- **`OnShutdown(err)`** — error out every waiter you still hold (active swap waiters and queued requests). Don't touch processes; teardown is `baseRouter`'s job.
Expose a constructor matching the `Factory` shape:
```go
func NewMyScheduler(name string, logger *logmon.Monitor, swapper Swapper, eff Effects) *MyScheduler {
// ...
}
// in the concrete router:
base := newBaseRouter(name, conf, processes, proxylog,
func(name string, logger *logmon.Monitor, eff scheduler.Effects) scheduler.Scheduler {
return scheduler.NewMyScheduler(name, logger, swapper, eff)
})
```
## Testing
- **Schedulers** are tested as pure state machines in the `scheduler` package:
drive the `On*` methods directly against a `fakeEffects` and assert on the
recorded grants/starts/stops. No goroutines, no sleeps. See
`scheduler/fifo_test.go` as the reference; follow the `TestSchedulerName_<scenario>`
naming convention.
- **`baseRouter` mechanism** (run loop, `grant`/`ServeHTTP`, `Unload`,
`Shutdown`) is tested in `base_test.go`. The run loop exposes a
`testProcessed` channel so tests can wait for an event to be fully processed
instead of sleeping.
- Run new tests with `go test -v -run TestMyScheduler_... ./internal/router/scheduler/`,
then `make test-dev` for a quick `go test` + `staticcheck` pass over `proxy/`.
+106
View File
@@ -0,0 +1,106 @@
package router
import (
"fmt"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type Group struct {
*baseRouter
}
func NewGroup(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Group, error) {
modelToGroup := make(map[string]string)
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
for _, mid := range gcfg.Members {
if existing, dup := modelToGroup[mid]; dup {
return nil, fmt.Errorf("model %q is in multiple groups: %q and %q", mid, existing, gid)
}
modelToGroup[mid] = gid
}
}
swapper := &groupSwapper{
config: conf,
modelToGroup: modelToGroup,
}
processes := make(map[string]process.Process, len(modelToGroup))
base, err := newBaseRouter("group", conf, processes, proxylog, swapper)
if err != nil {
return nil, fmt.Errorf("creating base router: %w", err)
}
for mid := range modelToGroup {
modelCfg, _, ok := conf.FindConfig(mid)
if !ok {
base.shutdownFn()
base.procCancel()
return nil, fmt.Errorf("no model config for %q", mid)
}
procLog := logmon.NewWriter(upstreamlog)
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
if err != nil {
base.shutdownFn()
base.procCancel()
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
}
processes[mid] = p
}
g := &Group{baseRouter: base}
go base.run()
return g, nil
}
// groupSwapper decides evictions from static group configuration.
//
// Same-group siblings are stopped when the group has swap=true. Cross-group
// members are stopped only when the target's group is exclusive; loading a
// model from a non-exclusive group leaves running exclusive groups alone,
// matching the gotcha in the original ProcessGroup behaviour.
type groupSwapper struct {
config config.Config
modelToGroup map[string]string
}
func (p *groupSwapper) EvictionFor(target string, running []string) []string {
tg := p.modelToGroup[target]
tgCfg := p.config.Routing.Router.Settings.Groups[tg]
seen := make(map[string]struct{})
var result []string
consider := func(mID string) {
if mID == target {
return
}
if _, dup := seen[mID]; dup {
return
}
og := p.modelToGroup[mID]
switch {
case og == tg && tgCfg.Swap:
seen[mID] = struct{}{}
result = append(result, mID)
// the previous ProcessGroup behaviour did not unload exclusive groups
// when loading a non-exclusive model. This maintains that gotcha
// for backwards compatibility. The newer swap matrix approach does not
// have this issue.
case og != tg && tgCfg.Exclusive:
if ogCfg := p.config.Routing.Router.Settings.Groups[og]; !ogCfg.Persistent {
seen[mID] = struct{}{}
result = append(result, mID)
}
}
}
for _, mID := range running {
consider(mID)
}
return result
}
func (p *groupSwapper) OnSwapStart(target string, running []string) {}
+335
View File
@@ -0,0 +1,335 @@
package router
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// newTestGroup builds a Group directly from the supplied processes and config,
// bypassing NewGroup's call to process.New.
func newTestGroup(t *testing.T, conf config.Config, processes map[string]process.Process) *Group {
t.Helper()
modelToGroup := make(map[string]string)
for gid, gcfg := range conf.Routing.Router.Settings.Groups {
for _, mid := range gcfg.Members {
modelToGroup[mid] = gid
}
}
swapper := &groupSwapper{
config: conf,
modelToGroup: modelToGroup,
}
base, err := newBaseRouter("group", conf, processes, logmon.NewWriter(io.Discard), swapper)
if err != nil {
t.Fatalf("newBaseRouter: %v", err)
}
base.testProcessed = make(chan struct{}, 64)
g := &Group{baseRouter: base}
go base.run()
t.Cleanup(func() {
if !g.shuttingDown.Load() {
_ = g.Shutdown(time.Second)
}
})
return g
}
func TestGroup_NewGroup_DuplicateMembership(t *testing.T) {
conf := config.Config{
Routing: groupRouting(map[string]config.GroupConfig{
"g1": {Swap: true, Members: []string{"a"}},
"g2": {Swap: true, Members: []string{"a"}},
}),
Models: map[string]config.ModelConfig{
"a": {},
},
}
log := logmon.NewWriter(io.Discard)
if _, err := NewGroup(conf, log, log); err == nil {
t.Fatalf("expected error for duplicate membership")
}
}
func TestGroup_ServeHTTP_SwapStopsPrevious(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0) // park a Run goroutine so Stop has something to release
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Routing: groupRouting(map[string]config.GroupConfig{
"g": {Swap: true, Exclusive: true, Members: []string{"a", "b"}},
}),
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
if got := b.serveCalls.Load(); got != 1 {
t.Errorf("b.serveCalls=%d want 1", got)
}
}
func TestGroup_NonSwapGroup_NoStop(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Routing: groupRouting(map[string]config.GroupConfig{
"g": {Swap: false, Exclusive: false, Members: []string{"a", "b"}},
}),
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (swap=false should not stop siblings)", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
func TestGroup_CrossGroupExclusive(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Routing: groupRouting(map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: true, Members: []string{"b"}},
}),
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (cross-group exclusive must stop)", got)
}
}
// TestGroup_CrossGroupNonExclusiveParallel verifies that two requests for
// models in distinct non-exclusive groups load in parallel rather than
// serializing through the router's run loop.
func TestGroup_CrossGroupNonExclusiveParallel(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
conf := config.Config{
HealthCheckTimeout: 5,
Routing: groupRouting(map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: false, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
}),
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
g.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, g.testProcessed, 1)
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
g.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, g.testProcessed, 1)
// Both groups load concurrently — both must reach Run() before either is
// marked ready. If the router still serialised, only one would proceed.
<-a.runStarted
<-pb.runStarted
a.markReady()
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
}
if got := pb.stopCalls.Load(); got != 0 {
t.Errorf("b.stopCalls=%d want 0 (parallel groups don't evict each other)", got)
}
}
// TestGroup_SameGroupSwapSerialises verifies that two same-group requests
// (Swap=true) serialise even when both arrive while neither has reached
// StateStarting yet — the in-flight swap target the scheduler folds into the
// running set closes that race.
func TestGroup_SameGroupSwapSerialises(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
conf := config.Config{
HealthCheckTimeout: 5,
Routing: groupRouting(map[string]config.GroupConfig{
"g": {Swap: true, Exclusive: false, Members: []string{"a", "b"}},
}),
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
g.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, g.testProcessed, 1)
// Request B arrives before A transitions to StateStarting in the process
// state machine. Without folding the in-flight swap target into the running
// set, the swapper would not see A as running, and B would start in
// parallel, violating Swap=true.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
g.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, g.testProcessed, 1)
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b started in parallel: runCalls=%d want 0", got)
}
<-a.runStarted
a.markReady()
waitProcessed(t, g.testProcessed, 1) // swapDone(a) → b promoted
<-pb.runStarted
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
}
}
// TestGroup_PersistentNotEvicted verifies that a group with persistent=true
// is never evicted when another exclusive group starts loading. The running
// model in the persistent group stays alive alongside the new one.
func TestGroup_PersistentNotEvicted(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Routing: groupRouting(map[string]config.GroupConfig{
"persist": {Swap: true, Exclusive: false, Persistent: true, Members: []string{"a"}},
"other": {Swap: true, Exclusive: true, Members: []string{"b"}},
}),
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (persistent group must not be evicted)", got)
}
if a.State() != process.StateStarting && a.State() != process.StateReady {
t.Errorf("a state=%s want still running", a.State())
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestGroup_NonExclusiveDoesNotUnloadExclusive pins a backwards-compatible
// gotcha from the original ProcessGroup: when a model in a non-exclusive group
// is loaded, any running exclusive group keeps running. The two coexist.
func TestGroup_NonExclusiveDoesNotUnloadExclusive(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
conf := config.Config{
HealthCheckTimeout: 5,
Routing: groupRouting(map[string]config.GroupConfig{
"g1": {Swap: true, Exclusive: true, Members: []string{"a"}},
"g2": {Swap: true, Exclusive: false, Members: []string{"b"}},
}),
}
g := newTestGroup(t, conf, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
g.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (non-exclusive target must not unload exclusive group)", got)
}
if a.State() != process.StateStarting && a.State() != process.StateReady {
t.Errorf("a state=%s want still running", a.State())
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
+218
View File
@@ -0,0 +1,218 @@
package router
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// groupRouting builds a normalized RoutingConfig for the group router, mirroring
// what config.LoadConfigFromReader produces. Tests use it to populate
// config.Config.Routing without going through LoadConfig.
func groupRouting(groups map[string]config.GroupConfig) config.RoutingConfig {
return config.RoutingConfig{
Router: config.RouterConfig{
Use: "group",
Settings: config.RouterSettings{Groups: groups},
},
}
}
// fakeProcess is an in-memory implementation of process.Process used to drive
// the routers through their state machine without spawning real upstreams.
type fakeProcess struct {
id string
mu sync.Mutex
state process.ProcessState
readyCh chan struct{}
stopCh chan struct{}
runStarted chan struct{} // closed on the first Run call
stopStarted chan struct{} // closed on the first Stop call
autoReady bool
// serveBlock, when non-nil, makes ServeHTTP receive from it before
// writing its response. Tests use this to hold a request in-flight.
// Closing the channel releases every blocked ServeHTTP caller.
serveBlock chan struct{}
// serveStarted is closed on the first ServeHTTP entry, letting tests
// wait deterministically for the handler to begin executing.
serveStarted chan struct{}
// stopBlock, when non-nil, makes Stop receive from it (after signalling
// stopStarted) before completing. Tests use this to prove that several
// Stop calls can be in flight simultaneously.
stopBlock chan struct{}
runCalls atomic.Int32
stopCalls atomic.Int32
serveCalls atomic.Int32
// inFlightServe counts ServeHTTP calls currently inside the handler.
// stoppedWhileServing flips true if Stop is ever called while that
// counter is non-zero — a direct, race-free observation of the
// "swap mid-request" anti-property.
inFlightServe atomic.Int32
stoppedWhileServing atomic.Bool
}
func newFakeProcess(id string) *fakeProcess {
return &fakeProcess{
id: id,
state: process.StateStopped,
readyCh: make(chan struct{}),
stopCh: make(chan struct{}),
runStarted: make(chan struct{}),
stopStarted: make(chan struct{}),
serveStarted: make(chan struct{}),
}
}
func (f *fakeProcess) setState(s process.ProcessState) {
f.mu.Lock()
defer f.mu.Unlock()
f.state = s
if s == process.StateReady {
select {
case <-f.readyCh:
default:
close(f.readyCh)
}
}
}
func (f *fakeProcess) State() process.ProcessState {
f.mu.Lock()
defer f.mu.Unlock()
return f.state
}
func (f *fakeProcess) markReady() { f.setState(process.StateReady) }
func (f *fakeProcess) Run(_ time.Duration) error {
f.runCalls.Add(1)
f.mu.Lock()
if f.state != process.StateStopped {
s := f.state
f.mu.Unlock()
return fmt.Errorf("fakeProcess %s: Run called while %s", f.id, s)
}
f.state = process.StateStarting
sc := f.stopCh
select {
case <-f.runStarted:
default:
close(f.runStarted)
}
f.mu.Unlock()
if f.autoReady {
f.setState(process.StateReady)
}
<-sc
return nil
}
func (f *fakeProcess) Stop(_ time.Duration) error {
f.stopCalls.Add(1)
if f.inFlightServe.Load() > 0 {
f.stoppedWhileServing.Store(true)
}
f.mu.Lock()
select {
case <-f.stopStarted:
default:
close(f.stopStarted)
}
f.mu.Unlock()
// Test hook: hold Stop here so the test can prove multiple Stops are
// in flight at the same time before any of them complete.
if f.stopBlock != nil {
<-f.stopBlock
}
f.mu.Lock()
defer f.mu.Unlock()
if f.state == process.StateStopped {
return nil
}
f.state = process.StateStopped
select {
case <-f.stopCh:
default:
close(f.stopCh)
}
return nil
}
func (f *fakeProcess) WaitReady(ctx context.Context) error {
f.mu.Lock()
if f.state == process.StateReady {
f.mu.Unlock()
return nil
}
rc := f.readyCh
f.mu.Unlock()
select {
case <-rc:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (f *fakeProcess) Logger() *logmon.Monitor { return logmon.NewWriter(io.Discard) }
func (f *fakeProcess) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
f.serveCalls.Add(1)
f.inFlightServe.Add(1)
defer f.inFlightServe.Add(-1)
f.mu.Lock()
select {
case <-f.serveStarted:
default:
close(f.serveStarted)
}
f.mu.Unlock()
if f.serveBlock != nil {
<-f.serveBlock
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "ok:%s", f.id)
}
// waitProcessed drains n events from ch, fataling on timeout. One event fires
// per handlerReq or swapDone fully absorbed by run().
func waitProcessed(t *testing.T, ch chan struct{}, n int) {
t.Helper()
for i := 0; i < n; i++ {
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatalf("waitProcessed: only %d/%d events received", i, n)
}
}
}
func newRequest(model string) *http.Request {
body := fmt.Sprintf(`{"model":%q}`, model)
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
r.Header.Set("Content-Type", "application/json")
return r
}
func newRequestCtx(ctx context.Context, model string) *http.Request {
return newRequest(model).WithContext(ctx)
}
+277
View File
@@ -0,0 +1,277 @@
package router
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strings"
"sync"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
var loadingPaths = []string{
"/v1/chat/completions",
}
func isLoadingPath(path string) bool {
for _, p := range loadingPaths {
if strings.HasPrefix(path, p) {
return true
}
}
return false
}
type loadingWriter struct {
hasWritten bool
writer http.ResponseWriter
req *http.Request
ctx context.Context
logger *logmon.Monitor
modelName string
startTime time.Time
pendingMu sync.Mutex
pendingUpdate string
// writeMu serializes writes to the underlying writer and guards released.
// Once released is set, the streaming goroutine must not touch the writer
// again — ServeHTTP has reclaimed it (to run the real handler or to return)
// and writing/flushing a finalized response panics.
writeMu sync.Mutex
released bool
// closed by start when the goroutine finishes (after cleanup messages)
done chan struct{}
// test-only: closed when start enters its loop
loopStarted chan struct{}
// test-only: override the 1s tick interval
tickDuration time.Duration
// test-only: override character streaming speed (0 = no delay)
charPerSecond float64
}
func newLoadingWriter(logger *logmon.Monitor, modelName string, w http.ResponseWriter, req *http.Request) *loadingWriter {
s := &loadingWriter{
writer: w,
req: req,
ctx: req.Context(),
logger: logger,
modelName: modelName,
startTime: time.Now(),
tickDuration: 750 * time.Millisecond,
charPerSecond: 75,
}
s.Header().Set("Content-Type", "text/event-stream")
s.Header().Set("Cache-Control", "no-cache")
s.Header().Set("Connection", "keep-alive")
s.WriteHeader(http.StatusOK)
s.sendLine("━━━━━")
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", modelName))
return s
}
func (s *loadingWriter) setUpdate(msg string) {
s.pendingMu.Lock()
s.pendingUpdate = msg
s.pendingMu.Unlock()
}
func (s *loadingWriter) start(ctx context.Context) {
s.done = make(chan struct{})
defer close(s.done)
defer func() {
// Skip cleanup writes if the client disconnected — the connection
// is being torn down and flushing against it will panic.
if s.ctx.Err() != nil {
return
}
duration := time.Since(s.startTime)
s.sendData("\n")
s.sendLine(fmt.Sprintf("Done! (%.2fs)", duration.Seconds()))
s.sendLine("━━━━━")
s.sendLine(" ")
}()
remarks := make([]string, len(loadingRemarks))
copy(remarks, loadingRemarks)
rand.Shuffle(len(remarks), func(i, j int) {
remarks[i], remarks[j] = remarks[j], remarks[i]
})
ri := 0
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
lastRemarkTime := time.Time{}
ticker := time.NewTicker(s.tickDuration)
defer ticker.Stop()
if s.loopStarted != nil {
close(s.loopStarted)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.pendingMu.Lock()
update := s.pendingUpdate
s.pendingUpdate = ""
s.pendingMu.Unlock()
if update != "" {
s.sendData("\n")
s.sendInline(update)
s.sendData(" ")
lastRemarkTime = time.Now()
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else if time.Since(lastRemarkTime) >= nextRemarkIn {
remark := remarks[ri%len(remarks)]
ri++
s.sendData("\n")
s.sendInline(remark)
s.sendData(" ")
lastRemarkTime = time.Now()
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else {
s.sendData(".")
}
}
}
}
func (s *loadingWriter) waitForCompletion(timeout time.Duration) bool {
if s.done == nil {
return true
}
select {
case <-s.done:
return true
case <-time.After(timeout):
return false
}
}
func (s *loadingWriter) sendInline(text string) {
chunkSize := 10
if s.charPerSecond > 0 {
chunkSize = max(3, int(s.charPerSecond)/15)
}
runes := []rune(text)
for i := 0; i < len(runes); {
select {
case <-s.ctx.Done():
return
default:
}
end := i + chunkSize
if end > len(runes) {
end = len(runes)
}
chunk := string(runes[i:end])
s.sendData(chunk)
i = end
if i < len(runes) && s.charPerSecond > 0 {
time.Sleep(time.Duration(float64(time.Second) * float64(len(chunk)) / s.charPerSecond))
}
}
}
func (s *loadingWriter) sendLine(line string) {
if line == "" {
s.sendData("\n")
return
}
s.sendInline(line)
s.sendData("\n")
}
func (s *loadingWriter) sendData(data string) {
type Delta struct {
ReasoningContent string `json:"reasoning_content"`
}
type Choice struct {
Delta Delta `json:"delta"`
}
type SSEMessage struct {
Choices []Choice `json:"choices"`
}
msg := SSEMessage{
Choices: []Choice{
{
Delta: Delta{
ReasoningContent: data,
},
},
},
}
jsonData, err := json.Marshal(msg)
if err != nil {
s.logger.Errorf("<%s> Failed to marshal SSE message: %v", s.modelName, err)
return
}
s.writeMu.Lock()
defer s.writeMu.Unlock()
// Once ServeHTTP has reclaimed the writer (release), writing/flushing it
// races the real handler or panics on a finalized response. Stop here.
if s.released {
return
}
if _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData); err != nil {
s.logger.Debugf("<%s> Failed to write SSE data (client likely disconnected): %v", s.modelName, err)
return
}
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
// release fences the loadingWriter off from the underlying ResponseWriter.
// After it returns, the streaming goroutine will not write to or flush the
// writer again: any in-flight write completes under writeMu first, and later
// writes short-circuit on released. The caller can then safely hand the writer
// to the real handler or let ServeHTTP return without racing a finalized
// response (a use-after-return Flush panics on the recycled *bufio.Writer).
func (s *loadingWriter) release() {
s.writeMu.Lock()
s.released = true
s.writeMu.Unlock()
}
func (s *loadingWriter) Header() http.Header {
return s.writer.Header()
}
func (s *loadingWriter) Write(data []byte) (int, error) {
return s.writer.Write(data)
}
func (s *loadingWriter) WriteHeader(statusCode int) {
if s.hasWritten {
return
}
s.hasWritten = true
s.writer.WriteHeader(statusCode)
s.Flush()
}
func (s *loadingWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
+133
View File
@@ -0,0 +1,133 @@
package router
var loadingRemarks = []string{
"Still faster than your last standup meeting",
"Reticulating splines",
"Waking up the hamsters",
"Teaching the model manners",
"Convincing the GPU to participate",
"Loading weights (they're heavy)",
"Please enjoy this elevator music in your head",
"Pretending to be productive",
"Reading the entire internet, page by page",
"Staring at the abyss, the abyss is buffering",
"Applying layer after layer of disembodied cognition",
"Remembering everything it forgot during quantization",
"Counting to 405 billion, one parameter at a time",
"Summoning the stochastic parroting",
"Hold on, the GPU is questioning its existence",
"Deciding which facts to hallucinate today",
"Untangling the transformer spaghetti",
"Warming up the token soup",
"Your prompt is in a queue, behind 7 billion other thoughts",
"Running `sudo apt-get install intelligence`",
"Defragmenting the latent space",
"Polishing each matrix multiplication by hand",
"Whispering sweet nothings to the attention heads",
"Aligning with human values, one reluctant epoch at a time",
"The model is thinking about what it's about to think about",
"Loading... and by loading we mean making you wait",
"Spinning up the cloud GPU, please be patient while we burn your credits",
"Applying duct tape to the context window",
"Bribing the GPU scheduler for a timeslice",
"Would you like to hear a fun fact while we load? Too bad.",
"Hot swapping your sanity for an LLM",
"Compressing optimism into FP16",
"Ignoring 90% of the attention to save you 50% of the time",
"Counting the exact same thing three times just to be sure",
"Sorry, the inference you have reached is not in service",
"Rotating the positional encodings counterclockwise for good luck",
"Your call is very important to us. Please continue to hold.",
"Unpacking the blobs. All 300GB of them.",
"Initializing the thing that initializes the other thing",
"Converting electricity into existential dread",
"Flattening the curve... wait, the tensor. Flattening the tensor.",
"Fetching the fetch of a fetch, callback hell edition",
"The GPU is at 100%. The fan is now a helicopter.",
"Baking the weights at 350° for a golden-brown inference",
"Recalibrating the confidence of things it's still wrong about",
"Have you tried turning it off and on again? No? Good, wait here.",
"Simulating deep thought by pausing dramatically",
"Loading the model that knows more than you but still can't count r's in 'strawberry'",
"Convincing CUDA to cooperate. This may take a while.",
"VRAM: 23.9GB used of 24GB. Living on the edge.",
"Processing your request with the urgency of a DMV employee",
"This model was trained on the entire internet, including that embarrassing blog you wrote in 2008",
"Dispatching tokens through a series of increasingly confused matrix multiplies",
"Gently lowering your expectations",
"Applying softmax to our feelings about this load time",
"Autoregressively generating disappointment, one token at a time",
"The magic is happening. Somewhere. Probably.",
"Synchronizing the parallel processes that run in parallel but really don't",
"Calculating the meaning of life. Spoiler: it's 42, but we're double-checking.",
"Loading... just like it said 30 seconds ago. And will say 30 seconds from now.",
"Pre-warming the cache so the first query is only slightly slower than the rest",
"Have you considered that maybe your question wasn't worth all this compute?",
"Downloading more RAM (no, really, we're mmap-ing the weights)",
"Translating your prompt into math it barely understands",
"Estimating your time remaining with 0% accuracy",
"Buffering enthusiasm",
"Model is loading. Go make some coffee. Or a three-course meal.",
"Tokenizing the dictionary, filing a grievance on behalf of 'antidisestablishmentarianism'",
"Polling for readiness in a loop that would make your CS professor weep",
"Performing percussive maintenance on the attention mechanism",
"This loading screen is singlehandedly reversing climate progress",
"Decompressing the hopes and dreams of thousands of underpaid labelers",
"Filling the key-value cache with the ghost of prompts past",
"Currently at step 3 of 9,742 of loading. We'll get there. Eventually.",
"If you stare at the spinner, it spins slower. It's science.",
"Multiplying matricies with the enthusiasm of a teenager doing chores",
"Applying `torch.nap()` until the model feels refreshed",
"Reacquainting the model with the concept of 'facts' it forgot during fine-tuning",
"Sorry for the wait. No, wait, we're not actually sorry.",
"Your GPU is now a space heater with a side hustle in linear algebra",
"Allocating memory like a billionaire allocates tax avoidance strategies",
"The model saw \"As an AI language model\" and won't stop saying it now",
"Installing dependencies you didn't know existed and will never use again",
"Re-reading 'Attention Is All You Need' for the 400th time",
"Convincing the embedding layer that context is overrated",
"Manually untangling the residual connections with a tiny comb",
"On hold with the cloud provider trying to explain why 8 H100s isn't enough",
"Adjusting temperatures: model is 0.7, server room is 104°F",
"Please hold while we justify this electricity bill to accounting",
"Stacking decoder blocks like a Jenga tower at a LAN party",
"Compensating for your lack of patience with our lack of speed",
"This is a loading screen comment. Loading screens have comments now. Welcome to the future.",
"Processing the entire works of Shakespeare backwards just in case",
"The model is loading slower than your last `npm install`",
"Rehearsing plausible-sounding explanations for why it got everything wrong",
"Populating the context with filler while you wait for actual content",
"Optimizing for BLEU score, which definitely correlates with making you laugh",
"Generating an embedding for each and every letter of the alphabet, individually",
"Coming soon: llama-swap v2 with actual performance improvements. Probably.",
"Loading a model larger than your attention span",
"Performing a seance to invoke the spirit of Geoff Hinton",
"Did you know loading screens were invented to prevent users from smashing their monitors? Now you do.",
"Converting all the internet's bad opinions into a surprisingly useful autocomplete",
"Laying down each layer with the care of a Michelin-starred pastry chef",
"Checking if the model still thinks birds are government drones. Yep.",
"Activating the neurons responsible for 'I cannot assist with that request'",
"This model was trained on the same internet that brought you Rickrolling. You're welcome.",
"Realigning the alignment so it aligns with the previous alignment",
"Running `nvidia-smi` and sighing heavily",
"If you close your eyes, the loading bar moves faster. Proven by science.",
"EULA said 'by using this software you agree to wait forever' and you clicked Accept",
"Zipping the GPUs to make them go faster",
"Padding the context window with existential padding",
"We could have used a smaller model but someone wanted 'quality'",
"Disentangling the latent space into something resembling coherence",
"Slow is smooth, smooth is fast, but this is just slow",
"Memory-mapping like it's a AAA title from 2012",
"Your patience has been tokenized and added to the training set. Thank you for your contribution.",
"Loading is CPU-bound and your CPU is busy regretting its life choices",
"Exploring the high-dimensional manifold of ways to say 'just a moment'",
"The model is experiencing a brief but intense moment of imposter syndrome",
"Initializing 7B parameters by rolling 7B 16-sided dice",
"Panic! at the disk I/O",
"Intelligence is loading... your definition of intelligence may vary",
"This model was distilled. Unlike your patience, which is evaporating.",
"Unzipping the model. It's a .gguf file, not a metaphor.",
"Running inference on the concept of 'soon' to estimate remaining time",
"Loading with all the speed of a government-funded IT project",
"A blank terminal is a terrible thing to waste. Here's a loading message instead.",
}
+265
View File
@@ -0,0 +1,265 @@
package router
import (
"bufio"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
func TestLoadingWriter_SSEHeadersAndInitialMessage(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
if ct := lw.Header().Get("Content-Type"); ct != "text/event-stream" {
t.Errorf("Content-Type: want text/event-stream, got %q", ct)
}
if cc := lw.Header().Get("Cache-Control"); cc != "no-cache" {
t.Errorf("Cache-Control: want no-cache, got %q", cc)
}
if conn := lw.Header().Get("Connection"); conn != "keep-alive" {
t.Errorf("Connection: want keep-alive, got %q", conn)
}
body := w.Body.String()
if !strings.HasPrefix(body, "data: ") {
t.Errorf("expected SSE data: prefix, got: %s", body)
}
content := extractStreamedContent(body)
if !strings.Contains(content, "━━━━━\n") {
t.Errorf("missing separator in streamed content: %q", content)
}
if !strings.Contains(content, "llama-swap loading model: test-model\n") {
t.Errorf("missing initial message in streamed content: %q", content)
}
}
func TestLoadingWriter_WriteHeaderOnce(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.WriteHeader(http.StatusCreated)
if w.Code != http.StatusOK {
t.Errorf("first WriteHeader: want %d, got %d", http.StatusOK, w.Code)
}
}
func TestLoadingWriter_WritePassthrough(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.Write([]byte("hello"))
lw.Flush()
body := w.Body.String()
if !strings.Contains(body, "hello") {
t.Errorf("Write passthrough failed, body: %s", body)
}
}
func TestLoadingWriter_StartStopsOnCancel(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go lw.start(ctx)
<-lw.loopStarted
cancel()
if !lw.waitForCompletion(time.Second) {
t.Fatal("waitForCompletion timed out")
}
body := w.Body.String()
if !strings.Contains(body, "Done!") {
t.Errorf("expected Done! message, body: %s", body)
}
}
func TestLoadingWriter_StartShowsSetUpdate(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.charPerSecond = 0
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go lw.start(ctx)
<-lw.loopStarted
lw.setUpdate("custom status message")
time.Sleep(50 * time.Millisecond)
cancel()
if !lw.waitForCompletion(time.Second) {
t.Fatal("waitForCompletion timed out")
}
body := w.Body.String()
content := extractStreamedContent(body)
if !strings.Contains(content, "custom status message") {
t.Errorf("expected setUpdate message in output, got: %q", content)
}
}
func TestLoadingWriter_SendDataFormat(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.sendData("hello world")
body := w.Body.String()
if !strings.Contains(body, `"reasoning_content":"hello world"`) {
t.Errorf("expected reasoning_content in SSE data, body: %s", body)
}
if !strings.HasPrefix(body, "data: ") {
t.Errorf("expected data: prefix, got: %s", body)
}
}
func TestLoadingWriter_SendLine(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.charPerSecond = 0
// Capture only the content from this sendLine call
before := w.Body.Len()
lw.sendLine("line content")
after := w.Body.Len()
chunkBody := w.Body.String()[before:after]
content := extractStreamedContent(chunkBody)
if content != "line content\n" {
t.Errorf("expected complete streamed line, got: %q", content)
}
}
func TestLoadingWriter_FlushesPeriodicallyDuringStatusUpdates(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
lw.tickDuration = 10 * time.Millisecond
lw.charPerSecond = 0
lw.loopStarted = make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
lw.start(ctx)
close(done)
}()
<-lw.loopStarted
time.Sleep(50 * time.Millisecond)
cancel()
<-done
body := w.Body.String()
lines := countSSEMessages(body)
if lines < 2 {
t.Errorf("expected multiple SSE messages from periodic updates, got %d", lines)
}
}
func TestLoadingWriter_ReqStored(t *testing.T) {
logger := logmon.NewWriter(io.Discard)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
lw := newLoadingWriter(logger, "test-model", w, req)
if lw.req != req {
t.Fatal("req not stored")
}
}
func TestIsLoadingPath(t *testing.T) {
tests := []struct {
path string
want bool
}{
{"/v1/chat/completions", true},
{"/v1/chat/completions/extra", true},
{"/v1/completions", false},
{"/v1/embeddings", false},
{"/health", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := isLoadingPath(tt.path); got != tt.want {
t.Errorf("isLoadingPath(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
func countSSEMessages(s string) int {
scanner := bufio.NewScanner(strings.NewReader(s))
count := 0
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
count++
}
}
return count
}
func extractStreamedContent(body string) string {
var result strings.Builder
scanner := bufio.NewScanner(strings.NewReader(body))
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
jsonData := strings.TrimPrefix(line, "data: ")
var msg struct {
Choices []struct {
Delta struct {
ReasoningContent string `json:"reasoning_content"`
} `json:"delta"`
} `json:"choices"`
}
if err := json.Unmarshal([]byte(jsonData), &msg); err != nil {
continue
}
if len(msg.Choices) > 0 {
result.WriteString(msg.Choices[0].Delta.ReasoningContent)
}
}
return result.String()
}
+72
View File
@@ -0,0 +1,72 @@
package router
import (
"fmt"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
type Matrix struct {
*baseRouter
}
func NewMatrix(conf config.Config, proxylog, upstreamlog *logmon.Monitor) (*Matrix, error) {
mtx := conf.Routing.Router.Settings.Matrix
if mtx == nil {
return nil, fmt.Errorf("matrix router requires a matrix configuration")
}
swapper := &matrixSwapper{
solver: newMatrixSolver(mtx.ExpandedSets, mtx.ResolvedEvictCosts()),
logger: proxylog,
}
// Build a process for every model in the config. Any model can run alone
// even if it is not part of a set; this mirrors proxy.NewMatrix.
processes := make(map[string]process.Process, len(conf.Models))
base, err := newBaseRouter("matrix", conf, processes, proxylog, swapper)
if err != nil {
return nil, fmt.Errorf("creating base router: %w", err)
}
for mid, modelCfg := range conf.Models {
procLog := logmon.NewWriter(upstreamlog)
p, err := process.New(base.procCtx, mid, modelCfg, procLog, proxylog)
if err != nil {
base.shutdownFn()
base.procCancel()
return nil, fmt.Errorf("creating process for %q: %w", mid, err)
}
processes[mid] = p
}
r := &Matrix{baseRouter: base}
go base.run()
return r, nil
}
// matrixSwapper decides evictions by asking the matrix solver against the
// running set the scheduler hands it.
type matrixSwapper struct {
solver *matrixSolver
logger *logmon.Monitor
}
func (p *matrixSwapper) EvictionFor(target string, running []string) []string {
return p.solver.Solve(target, running).Evict
}
func (p *matrixSwapper) OnSwapStart(target string, running []string) {
result := p.solver.Solve(target, running)
switch {
case len(result.Evict) > 0:
p.logger.Infof("matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
target, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
case len(running) == 0:
p.logger.Infof("matrix: model=%s starting (no models running)", target)
default:
p.logger.Debugf("matrix: model=%s already running in set=%s dsl=%q", target, result.SetName, result.DSL)
}
}
+132
View File
@@ -0,0 +1,132 @@
package router
import (
"slices"
"github.com/mostlygeek/llama-swap/internal/config"
)
// matrixSolver contains pure swap-decision logic with no Process dependencies.
// It is safe for concurrent reads after construction.
type matrixSolver struct {
expandedSets []config.ExpandedSet // all valid model combinations
evictCosts map[string]int // real model name -> eviction cost (default 1)
modelToSets map[string][]int // model name -> indices into expandedSets
}
func newMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *matrixSolver {
modelToSets := make(map[string][]int)
for i, es := range expandedSets {
for _, model := range es.Models {
modelToSets[model] = append(modelToSets[model], i)
}
}
return &matrixSolver{
expandedSets: expandedSets,
evictCosts: evictCosts,
modelToSets: modelToSets,
}
}
// solveResult describes what the solver decided.
type solveResult struct {
Evict []string // running models that must be stopped
TargetSet []string // the chosen set of models (for informational purposes)
SetName string // name of the chosen set
DSL string // original DSL expression for the chosen set
TotalCost int // total eviction cost
}
// Solve determines which models to evict when a model is requested.
//
// Algorithm:
// 1. If requestedModel is already running, no eviction needed.
// 2. Find all sets containing requestedModel.
// 3. If no sets found, the model runs alone; evict all running models.
// 4. For each candidate set, compute cost = sum of evict_costs for running
// models NOT in that set.
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
// 6. Return models to evict and the chosen set.
func (s *matrixSolver) Solve(requestedModel string, runningModels []string) solveResult {
if slices.Contains(runningModels, requestedModel) {
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
return solveResult{
TargetSet: runningModels,
SetName: setName,
DSL: dsl,
}
}
candidateIndices := s.modelToSets[requestedModel]
// Model not in any set: runs alone, evict everything.
if len(candidateIndices) == 0 {
evict := make([]string, len(runningModels))
copy(evict, runningModels)
return solveResult{
Evict: evict,
TargetSet: []string{requestedModel},
}
}
bestCost := -1
bestIdx := -1
for _, idx := range candidateIndices {
setModels := s.expandedSets[idx].Models
cost := 0
for _, running := range runningModels {
if !slices.Contains(setModels, running) {
cost += s.evictCost(running)
}
}
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
bestCost = cost
bestIdx = idx
}
}
chosen := s.expandedSets[bestIdx]
var evict []string
for _, running := range runningModels {
if !slices.Contains(chosen.Models, running) {
evict = append(evict, running)
}
}
return solveResult{
Evict: evict,
TargetSet: chosen.Models,
SetName: chosen.SetName,
DSL: chosen.DSL,
TotalCost: bestCost,
}
}
// findMatchingSet finds the expanded set that contains all running models.
// Returns the set name and DSL, or empty strings if no match.
func (s *matrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
for _, idx := range s.modelToSets[requestedModel] {
set := s.expandedSets[idx]
allInSet := true
for _, m := range runningModels {
if !slices.Contains(set.Models, m) {
allInSet = false
break
}
}
if allInSet {
return set.SetName, set.DSL
}
}
return "", ""
}
func (s *matrixSolver) evictCost(model string) int {
if cost, ok := s.evictCosts[model]; ok {
return cost
}
return 1
}
+247
View File
@@ -0,0 +1,247 @@
package router
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// newTestMatrix builds a Matrix router from supplied processes, bypassing
// NewMatrix's call to process.New.
func newTestMatrix(t *testing.T, conf config.Config, expanded []config.ExpandedSet, evictCosts map[string]int, processes map[string]process.Process) *Matrix {
t.Helper()
logger := logmon.NewWriter(io.Discard)
swapper := &matrixSwapper{
solver: newMatrixSolver(expanded, evictCosts),
logger: logger,
}
base, err := newBaseRouter("matrix", conf, processes, logger, swapper)
if err != nil {
t.Fatalf("newBaseRouter: %v", err)
}
base.testProcessed = make(chan struct{}, 64)
r := &Matrix{baseRouter: base}
go base.run()
t.Cleanup(func() {
if !r.shuttingDown.Load() {
_ = r.Shutdown(time.Second)
}
})
return r
}
func baseMatrixConfig() config.Config {
return config.Config{
HealthCheckTimeout: 5,
Matrix: &config.MatrixConfig{},
}
}
// TestMatrix_SwapEvictsConflicting verifies that loading a model triggers
// eviction of running models that are not in any shared set with it.
func TestMatrix_SwapEvictsConflicting(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0) // park a Run goroutine so Stop has something to release
b := newFakeProcess("b")
b.autoReady = true
// Two single-model sets: a and b never coexist, so loading b must evict a.
expanded := []config.ExpandedSet{
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
r.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestMatrix_CoexistInSet verifies that a model is not evicted when the target
// shares a set with it (the fast path applies if the target is already ready).
func TestMatrix_CoexistInSet(t *testing.T) {
a := newFakeProcess("a")
a.markReady()
go a.Run(0)
b := newFakeProcess("b")
b.autoReady = true
// Both fit in s_ab, so b's swap should not stop a.
expanded := []config.ExpandedSet{
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": b})
w := httptest.NewRecorder()
r.ServeHTTP(w, newRequest("b"))
if w.Code != http.StatusOK {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
}
if got := b.runCalls.Load(); got != 1 {
t.Errorf("b.runCalls=%d want 1", got)
}
}
// TestMatrix_CoexistingSetParallel verifies that two models that share an
// expanded set load in parallel — the solver returns empty Evict for both,
// the collision predicate clears them, and both swaps run together.
func TestMatrix_CoexistingSetParallel(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
expanded := []config.ExpandedSet{
{SetName: "s_ab", DSL: "a & b", Models: []string{"a", "b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
r.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, r.testProcessed, 1)
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
r.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, r.testProcessed, 1)
<-a.runStarted
<-pb.runStarted
a.markReady()
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 0 {
t.Errorf("a.stopCalls=%d want 0 (coexists with b)", got)
}
if got := pb.stopCalls.Load(); got != 0 {
t.Errorf("b.stopCalls=%d want 0 (coexists with a)", got)
}
}
// TestMatrix_IncompatibleQueues verifies that the second request for a model
// that cannot coexist with the in-flight first model queues until the first
// completes, and then evicts it. This exercises the scheduler folding in-flight
// swap targets into the running set it hands the swapper.
func TestMatrix_IncompatibleQueues(t *testing.T) {
a := newFakeProcess("a")
pb := newFakeProcess("b")
expanded := []config.ExpandedSet{
{SetName: "s_a", DSL: "a", Models: []string{"a"}},
{SetName: "s_b", DSL: "b", Models: []string{"b"}},
}
r := newTestMatrix(t, baseMatrixConfig(), expanded, nil, map[string]process.Process{"a": a, "b": pb})
w1 := httptest.NewRecorder()
done1 := make(chan struct{})
go func() {
r.ServeHTTP(w1, newRequest("a"))
close(done1)
}()
waitProcessed(t, r.testProcessed, 1)
// B arrives before A transitions to StateStarting. The running set the
// scheduler builds includes A (an in-flight swap target), so the solver
// returns evict=[a] and collidesWith forces B to queue.
w2 := httptest.NewRecorder()
done2 := make(chan struct{})
go func() {
r.ServeHTTP(w2, newRequest("b"))
close(done2)
}()
waitProcessed(t, r.testProcessed, 1)
if got := pb.runCalls.Load(); got != 0 {
t.Errorf("b started in parallel: runCalls=%d want 0", got)
}
<-a.runStarted
a.markReady()
waitProcessed(t, r.testProcessed, 1) // swapDone(a) → b promoted, evicts a
<-pb.runStarted
pb.markReady()
for i, ch := range []chan struct{}{done1, done2} {
select {
case <-ch:
case <-time.After(time.Second):
t.Fatalf("request %d did not complete", i)
}
}
if got := a.stopCalls.Load(); got != 1 {
t.Errorf("a.stopCalls=%d want 1 (b's swap must stop a)", got)
}
}
// TestMatrixSolver_TieBreakDefinitionOrder pins the solver's tie-break rule:
// when multiple candidate sets have equal eviction cost, the earlier-defined
// set wins.
func TestMatrixSolver_TieBreakDefinitionOrder(t *testing.T) {
expanded := []config.ExpandedSet{
{SetName: "first", DSL: "a & b", Models: []string{"a", "b"}},
{SetName: "second", DSL: "a & c", Models: []string{"a", "c"}},
}
s := newMatrixSolver(expanded, nil)
// No models running, request "a": both sets have cost 0 and contain a.
// Definition order: "first" wins.
result := s.Solve("a", nil)
if result.SetName != "first" {
t.Errorf("SetName=%q want %q", result.SetName, "first")
}
}
// TestMatrixSolver_EvictCostsPreferred verifies that higher evict costs steer
// the solver toward a cheaper set.
func TestMatrixSolver_EvictCostsPreferred(t *testing.T) {
// b is expensive to evict; c is cheap. Request "a" with both b and c
// running. The solver should pick the set that keeps b.
expanded := []config.ExpandedSet{
{SetName: "a_with_c", DSL: "a & c", Models: []string{"a", "c"}}, // would evict b (cost 10)
{SetName: "a_with_b", DSL: "a & b", Models: []string{"a", "b"}}, // would evict c (cost 1)
}
s := newMatrixSolver(expanded, map[string]int{"b": 10, "c": 1})
result := s.Solve("a", []string{"b", "c"})
if result.SetName != "a_with_b" {
t.Errorf("SetName=%q want %q (keep expensive b)", result.SetName, "a_with_b")
}
if len(result.Evict) != 1 || result.Evict[0] != "c" {
t.Errorf("Evict=%v want [c]", result.Evict)
}
}
+188
View File
@@ -0,0 +1,188 @@
package router
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/shared"
)
type peerMember struct {
peerID string
reverseProxy *httputil.ReverseProxy
apiKey string
}
type Peer struct {
cfg config.Config
logger *logmon.Monitor
peers map[string]*peerMember
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
inflight sync.WaitGroup
}
func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
peers := cfg.Peers
modelMap := make(map[string]*peerMember)
peerIDs := make([]string, 0, len(peers))
for peerID := range peers {
peerIDs = append(peerIDs, peerID)
}
sort.Strings(peerIDs)
for _, peerID := range peerIDs {
peer := peers[peerID]
peerTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
}
reverseProxy := &httputil.ReverseProxy{
Transport: peerTransport,
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(peer.ProxyURL)
r.Out.Host = r.Out.URL.Host
},
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
logger.Warnf("peer %s: proxy error: %v", peerID, err)
errMsg := fmt.Sprintf("peer proxy error: %v", err)
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
}
http.Error(w, errMsg, http.StatusBadGateway)
}
pp := &peerMember{
peerID: peerID,
reverseProxy: reverseProxy,
apiKey: peer.ApiKey,
}
for _, modelID := range peer.Models {
if _, found := modelMap[modelID]; found {
logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
continue
}
modelMap[modelID] = pp
}
}
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
return &Peer{
cfg: cfg,
logger: logger,
peers: modelMap,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
}, nil
}
func (r *Peer) Handles(model string) bool {
_, ok := r.peers[model]
return ok
}
func (r *Peer) Shutdown(timeout time.Duration) error {
if !r.shuttingDown.CompareAndSwap(false, true) {
return fmt.Errorf("shutdown already in progress")
}
if timeout == 0 {
r.shutdownFn()
r.inflight.Wait()
return nil
}
done := make(chan struct{})
go func() {
r.inflight.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-time.After(timeout):
r.shutdownFn()
r.inflight.Wait()
return fmt.Errorf("peer shutdown timed out after %v", timeout)
}
}
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.shuttingDown.Load() {
shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
return
}
r.inflight.Add(1)
defer r.inflight.Done()
data, err := shared.FetchContext(req, r.cfg)
if err != nil {
shared.SendError(w, req, err)
return
}
pp, found := r.peers[data.ModelID]
if !found {
r.logger.Warnf("peer model not found: %s", data.ModelID)
shared.SendError(w, req, ErrNoPeerModelFound)
return
}
r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID)
if pp.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+pp.apiKey)
req.Header.Set("x-api-key", pp.apiKey)
}
// Cancel the proxy request when the client disconnects or shutdown times out.
// AfterFunc links both parent contexts to our child without a goroutine leak.
ctx, cancel := context.WithCancel(context.Background())
stopReq := context.AfterFunc(req.Context(), cancel)
stopShutdown := context.AfterFunc(r.shutdownCtx, cancel)
req = req.WithContext(ctx)
pp.reverseProxy.ServeHTTP(w, req)
stopShutdown()
stopReq()
cancel()
}
+612
View File
@@ -0,0 +1,612 @@
package router
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/shared"
)
var testLogger = logmon.NewWriter(os.Stdout)
func init() {
testLogger.SetLogLevel(logmon.LevelWarn)
}
func TestNewPeer_EmptyPeers(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
if pr == nil {
t.Fatal("expected non-nil Peer")
}
if len(pr.peers) != 0 {
t.Fatalf("expected empty peers map, got %d entries", len(pr.peers))
}
}
func TestNewPeer_SinglePeer(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
ApiKey: "test-key",
Models: []string{"model-a", "model-b"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 2 {
t.Fatalf("expected 2 entries, got %d", len(pr.peers))
}
if _, ok := pr.peers["model-a"]; !ok {
t.Error("expected model-a to be mapped")
}
if _, ok := pr.peers["model-b"]; !ok {
t.Error("expected model-b to be mapped")
}
if _, ok := pr.peers["model-c"]; ok {
t.Error("expected model-c to not be mapped")
}
}
func TestNewPeer_MultiplePeers(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"model-a", "model-b"},
},
"peer2": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"model-c", "model-d"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 4 {
t.Fatalf("expected 4 entries, got %d", len(pr.peers))
}
for _, m := range []string{"model-a", "model-b", "model-c", "model-d"} {
if _, ok := pr.peers[m]; !ok {
t.Errorf("expected %s to be mapped", m)
}
}
}
func TestNewPeer_DuplicateModel(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"alpha-peer": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"duplicate-model"},
},
"beta-peer": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"duplicate-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
if len(pr.peers) != 1 {
t.Fatalf("expected 1 entry for duplicate model, got %d", len(pr.peers))
}
if _, ok := pr.peers["duplicate-model"]; !ok {
t.Error("expected duplicate-model to be mapped")
}
}
func TestPeer_ServeHTTP_Success(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("response from peer"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
if w.Body.String() != "response from peer" {
t.Errorf("expected 'response from peer', got %q", w.Body.String())
}
}
func TestPeer_ServeHTTP_ModelNotFoundInContext(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_PeerModelNotFound(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "nonexistent-model", ModelID: "nonexistent-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_ApiKeyInjection(t *testing.T) {
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "secret-api-key",
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if receivedAuthHeader != "Bearer secret-api-key" {
t.Errorf("expected 'Bearer secret-api-key', got %q", receivedAuthHeader)
}
}
func TestPeer_ServeHTTP_NoApiKey(t *testing.T) {
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "",
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if receivedAuthHeader != "" {
t.Errorf("expected no auth header, got %q", receivedAuthHeader)
}
}
func TestPeer_ServeHTTP_HostHeaderSet(t *testing.T) {
var receivedHost string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHost = r.Host
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if !strings.HasPrefix(receivedHost, "127.0.0.1:") {
t.Errorf("expected Host to start with '127.0.0.1:', got %q", receivedHost)
}
}
func TestPeer_ServeHTTP_SSEHeaderModification(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Header().Get("X-Accel-Buffering") != "no" {
t.Errorf("expected X-Accel-Buffering=no, got %q", w.Header().Get("X-Accel-Buffering"))
}
}
func TestPeer_ServeHTTP_ShutdownRejectsNewRequests(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "shutting down") {
t.Errorf("expected 'shutting down' in body, got %q", w.Body.String())
}
}
func TestPeer_ServeHTTP_WaitsForInflightDuringShutdown(t *testing.T) {
started := make(chan struct{})
released := make(chan struct{})
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-released
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
}()
<-started
shutdownDone := make(chan error, 1)
go func() {
shutdownDone <- pr.Shutdown(500 * time.Millisecond)
}()
// Shutdown should be waiting on inflight. If it finished already something is wrong.
time.Sleep(100 * time.Millisecond)
select {
case err := <-shutdownDone:
t.Errorf("shutdown completed before inflight finished: %v", err)
default:
}
close(released)
wg.Wait()
select {
case err := <-shutdownDone:
if err != nil {
t.Errorf("shutdown errored after inflight completed: %v", err)
}
case <-time.After(2 * time.Second):
t.Error("shutdown did not complete after inflight finished")
}
}
func TestPeer_ServeHTTP_ShutdownTimeoutCancelsInflight(t *testing.T) {
started := make(chan struct{})
released := make(chan struct{})
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-released
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "test-model", ModelID: "test-model"}))
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
}()
<-started
err = pr.Shutdown(100 * time.Millisecond)
if err == nil {
t.Error("expected timeout error from shutdown")
}
close(released)
wg.Wait()
}
func TestPeer_ShutdownMultiple(t *testing.T) {
pr, err := NewPeer(config.Config{}, testLogger)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err != nil {
t.Fatal(err)
}
err = pr.Shutdown(0)
if err == nil {
t.Error("expected error on second shutdown")
}
if !strings.Contains(err.Error(), "already in progress") {
t.Errorf("expected 'already in progress', got %q", err.Error())
}
}
func TestPeer_ServeHTTP_ModelExtractedFromBody(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"extracted-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
body := strings.NewReader(`{"model":"extracted-model","prompt":"hello"}`)
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestPeer_ServeHTTP_ContextOverridesBodyModel(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"context-model"},
},
"peer2": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"body-model"},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
body := strings.NewReader(`{"model":"body-model","prompt":"hello"}`)
req := httptest.NewRequest("POST", "/v1/chat/completions", body)
req.Header.Set("Content-Type", "application/json")
*req = *req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: "context-model", ModelID: "context-model"}))
w := httptest.NewRecorder()
pr.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestNewPeer_CustomTimeouts(t *testing.T) {
proxyURL, _ := url.Parse("http://localhost:8080")
peers := config.PeerDictionaryConfig{
"test-peer": config.PeerConfig{
Proxy: "http://localhost:8080",
ProxyURL: proxyURL,
Models: []string{"model1"},
Timeouts: config.TimeoutsConfig{
Connect: 45,
ResponseHeader: 300,
TLSHandshake: 15,
ExpectContinue: 2,
IdleConn: 120,
},
},
}
pr, err := NewPeer(config.Config{Peers: peers}, testLogger)
if err != nil {
t.Fatal(err)
}
member, ok := pr.peers["model1"]
if !ok {
t.Fatal("expected model1 to be mapped")
}
transport, ok := member.reverseProxy.Transport.(*http.Transport)
if !ok {
t.Fatal("expected Transport to be *http.Transport")
}
if transport.ResponseHeaderTimeout != 300*time.Second {
t.Errorf("expected ResponseHeaderTimeout=%v, got %v", 300*time.Second, transport.ResponseHeaderTimeout)
}
if transport.TLSHandshakeTimeout != 15*time.Second {
t.Errorf("expected TLSHandshakeTimeout=%v, got %v", 15*time.Second, transport.TLSHandshakeTimeout)
}
if transport.ExpectContinueTimeout != 2*time.Second {
t.Errorf("expected ExpectContinueTimeout=%v, got %v", 2*time.Second, transport.ExpectContinueTimeout)
}
if transport.IdleConnTimeout != 120*time.Second {
t.Errorf("expected IdleConnTimeout=%v, got %v", 120*time.Second, transport.IdleConnTimeout)
}
if !transport.ForceAttemptHTTP2 {
t.Error("expected ForceAttemptHTTP2 to be true")
}
}

Some files were not shown because too many files have changed in this diff Show More