Compare commits

...

130 Commits

Author SHA1 Message Date
Benson Wong 231e62291c proxy: fix matrix race and process stop bug (#677)
- matrix.go change logic to consider any proxy.Process not in
StateStopped or StateShutdown
- process.StopImmediately, and Stop() which called it had a subtle bug
where it only handled state transitions from StateReady to
StateStopping. StateStarting -> StateStopping was ignored completely.

fix: #670
2026-04-20 00:21:11 -07:00
Benson Wong 57ac666598 .github/workflows: tweak push ghcr conditional (#676) 2026-04-19 13:56:26 -07:00
Benson Wong 69728301f5 .github/workflows: add toggle for pushing unified images to github (#672)
Add ability to dispatch (manually run) unified container builds in github without push to ghcr.io.
2026-04-19 10:10:48 -07:00
Benson Wong c176fa70f1 docker/unified: add spirv-headers to fix vulkan build (#669) 2026-04-18 12:18:10 -07:00
Benson Wong 5e3c646829 proxy: compress captures with zstd (#668)
The previous captures were saved uncompressed in memory. In agentic
workflows there can be many turns with each request containing the
previous context in the body with a lot of redundant data. Use zstd to
compress the request and response data before keeping a copy of memory.

Results: 

- Average Percentage Saved: 73.19%
- Average Compression Factor: ~6.77:1
2026-04-17 23:29:37 -07:00
Benson Wong c3f0d43e6e proxy: fix race conditions during swap (#667)
I pointed Opus 4.7 (high effort) at proxy.ProcessGroup to identify any
race conditions in the swapping code. It found a race condition where
there is a small window in the fast path for routing a request to a
loaded model. There is a very small window where:

- model M1 is loaded and ready for requests
- a request, R1, for M1 comes in 
- a request, R2, for M2 comes in almost immediately after
- R1 acquires the lock, sees M1 is loaded (fast path), releases the lock
`[race window]` and the request is ready to be forwarded
- the race window occurs between the release of the lock and the request
being forwarded
  - the lock is released so requests can be handled concurrently 
- R2 comes in within the `[race window]`, acquires the lock, triggers a
model swap to M2. stopping M1
- R1 is forwarded to a model that is unloaded or in the process of
shutting down creating an error response

In deployed systems the race window is very small and doesn't happen
often. However with #635 and PR #656 I though this deserved a bit more
attention. It is not concluded that this race is the cause of #635 but
the race is likely to happen more often under sustained or high load.

AI Note: Opus 4.7 x-high effort took about an hour to write the original
patch. With the pattern discovered the fix to matrix.go was very quick.
GLM 5.1 using the previous established patterns was able to easily write
the fix for ProcessGroup.StopProcesses().

Supersedes: #656
Updates: #277, #635
2026-04-17 21:23:17 -07:00
Benson Wong f6cf9f5844 proxy: Refactor tests (#660)
- use YAML for test configurations
- remove most uses of simple-responder, opting to use
process.testHandler

Fixes #655
2026-04-16 22:47:42 -07:00
Benson Wong 121fd93ad8 Makefile: restore linux arm64 targets
Fix #641
2026-04-14 22:05:39 -07:00
Benson Wong 17233e9278 docs: update configuration.md for matrix 2026-04-14 22:01:03 -07:00
Benson Wong 4866d16c3e README.md: update to use matrix instead of groups 2026-04-14 21:57:49 -07:00
Benson Wong 35193f82f1 proxy: add swap matrix with solver-based model swapping (#646)
Add a new swap matrix to supersede groups for running concurrent models.
The matrix uses a solver that picks the lowest cost evictions to make a
requested model available. This simple approach along with a very basic
DSL grammar can enable very complex swapping scenarios.

- add DSL parser for set expressions with & (AND), | (OR), (), +ref
- add MatrixConfig structs, validation, and topological sort for +ref
- add MatrixSolver with cost-minimizing swap decisions
- add Matrix runtime integrating solver with Process lifecycle
- integrate matrix into ProxyManager with if-branches at all endpoints
- update config.example.yaml and config-schema.json with matrix schema
- config enforces groups XOR matrix (cannot use both)

fixes #643
2026-04-14 21:55:30 -07:00
Benson Wong 40e39f7a86 ui-svelte: fix security issues (#649) 2026-04-12 16:21:31 -07:00
Benson Wong a9d840ffd7 proxy,proxy/config: restore timeouts to pre PR 619 (#648)
Reset the default ResponseHeader timeout to 0 (no timeout) which was set
to 60 seconds in PR #619.

Fixes #647
2026-04-11 20:42:13 -07:00
Benson Wong 7b2b82777f docker/unified: derive rootless image from root container (#644)
Build the root image once, then derive the rootless variant from it
using a small inline Dockerfile that adds the non-root user and chowns
the writable directories. This halves the number of CI jobs (4 → 2) and
eliminates the redundant full CUDA compilation for the rootless variant.

- remove RUN_UID build arg from build-image.sh
- derive rootless image inline after root build completes
- collapse variant matrix out of unified-docker.yml
- push both root and rootless tags in a single CI job

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 22:59:54 -07:00
Benson Wong d87f0ce2c5 docker/unified: publish rootless image variant (#630) 2026-04-07 03:05:53 -07:00
Leoy 06bc6a614c proxy: preserve wall-clock duration in metrics (#629)
Keep request duration from being underreported when upstream timings
only cover part of the full request lifecycle.

- compare wall-clock and upstream timing durations
- keep token and throughput values from timings
- add regression coverage for underreported timings

fixes #602
2026-04-07 01:52:41 -07:00
Ron M a37b4866d8 proxy: add configurable HTTP timeouts for models and peers (#619)
Add configurable HTTP timeout settings to both models and peers to support installations that requires longer timeouts than the current hardcoded defaults.

Closes #618
2026-04-06 19:30:27 +08:00
Benson Wong 981910d734 ci: validate config.example.yaml against config-schema.json (#627)
Extend the existing config-schema workflow to also validate
config.example.yaml against config-schema.json using check-jsonschema.

- add config.example.yaml to PR and push path triggers
- install check-jsonschema via pip
- run validation of config.example.yaml against schema

https://claude.ai/code/session_01Y1oqwE6mwNs9UTJgZRgXtG

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-04-05 15:17:57 +08:00
Benson Wong a185efe37e docker: make CMAKE_CUDA_ARCHITECTURES configurable via build arg (#625)
Expose CMAKE_CUDA_ARCHITECTURES as a Docker build ARG so users can
customize CUDA architectures via --build-arg without editing the
Dockerfile.

- convert hardcoded ENV to ARG with default, feeding into ENV
- replace silent fallback defaults (:-) in scripts with :? guards
  to fail fast if the env var is missing
- add usage example to Dockerfile header

Follow up to: #624

https://claude.ai/code/session_01EWiUe7jNABX7Uz95dUGJqK

Co-authored-by: Claude <noreply@anthropic.com>
2026-04-04 08:49:59 +08:00
Benson Wong 1dd1aadf93 docker/unified: add ik_llama.cpp to CUDA container (#620) 2026-04-03 15:16:30 +08:00
Benson Wong 955900972a add /sdapi to list of supported endpoints 2026-04-01 12:01:38 +08:00
Benson Wong c2c8cfaf81 docker/unified: build llama.cpp with static libraries (#616) 2026-04-01 03:38:07 +08:00
Benson Wong 1e440770ea ci: fix matrix exclude for scheduled docker workflow (#610) 2026-03-29 20:04:28 +09:00
Benson Wong c794273c83 docker/unified,.github: fix unified build (#606) 2026-03-27 10:31:12 +09:00
dependabot[bot] 6574a52cbb build(deps): bump picomatch from 4.0.3 to 4.0.4 in /ui-svelte (#605) 2026-03-26 22:28:24 +09:00
Benson Wong 8fabc75634 docker/unified: vulkan build fixes (#600)
multiple fixes to vulkan build: 

- use ubuntu 26.04 to be compatible with AMD 395+ (Strix halo) hardware
- add home directory in container 
- fix stable-diffusion install to actually enable vulkan

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-25 23:26:13 +09:00
Benson Wong e5e7391b6d .github,docker/unified: include vulkan build (#599)
Update docker/unified scripts to support building both cuda and vulkan unified images.
2026-03-25 06:58:28 +09:00
Benson Wong 2c282dccad .github,docker/unified: improve caching and fix bugs (#598)
- set up a GHA scheduled job to build the container nightly 
- enabling pushing a llama-swap:unified and a llama-swap:unified-Y-M-D
image to ghcr.io
- tidy up Dockerfile to use a non-root user and llama-swap as an entry
point
2026-03-23 22:24:40 +09:00
Benson Wong 916d13f5bd .github/workflows,docker/unified: add cuda based unified container (#597)
Add Docker build scripts for a unified cuda docker container with llama-server, stable-diffusion.cpp, whisper.cpp.
2026-03-22 21:11:54 +09:00
Benson Wong a3725e7d09 Update go.mod to 1.26.1 (#593) 2026-03-20 16:09:58 +09:00
Benson Wong 15bd55d3a9 proxy, ui-svelte: add /sdapi/v1 endpoint support (#587)
Add proxy routes for stable-diffusion.cpp's /sdapi/v1/txt2img,
/sdapi/v1/img2img, and /sdapi/v1/loras endpoints. POST endpoints
use proxyInferenceHandler (model in JSON body), GET /loras uses
proxyGETModelHandler (model in query param).

Update the image playground with a dual-mode UI supporting both
OpenAI and SDAPI backends. In SDAPI mode, loras are fetched first
to prime the server-side cache, and all txt2img parameters are
exposed (negative prompt, steps, cfg_scale, seed, batch_size,
clip_skip, sampler, scheduler, lora selection with multipliers).

- Add 3 sdapi route registrations in proxymanager.go
- Add sdApi.ts client with generateSdImage and fetchSdLoras
- Add SDAPI types (SdApiTxt2ImgRequest, SdApiResponse, etc.)
- Add /sdapi to vite dev proxy config
- Add backend tests for sdapi routing
- Support batch image display in gallery grid

https://claude.ai/code/session_0186MGX6NXdHVBTv2KH45fqn

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-03-19 22:08:31 +09:00
Benson Wong c3c258a55d proxy: fix metrics capture for v1/responses (#586)
properly parse anthropic compatible usage data from streaming responses.

closes: #577
2026-03-13 16:50:12 -07:00
Benson Wong 29a38fde0d ui-svelte: upgrade to vite 8 (#585)
Upgrade vite and related dependencies to take advantage of Vite 8's
improved build times via Rolldown and Oxc.

- vite: ^6.3.5 → ^8.0.0
- @sveltejs/vite-plugin-svelte: ^5.0.3 → ^7.0.0
- svelte: ^5.19.0 → ^5.46.4
- vite-plugin-compression2: ^2.4.0 → ^2.5.1
- vitest: ^4.0.18 → ^4.1.0

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-03-13 08:45:59 -07:00
tesuri d569681daa Change model sorting to natural order (#582)
Use natural sorting for model names.

Previously the model list was sorted lexicographically, which resulted
in unintuitive ordering when numbers were included in the name.

Example:

Before
qwen3.5:2B
qwen3.5:35B-3AB
qwen3.5:9B

After
qwen3.5:2B
qwen3.5:9B
qwen3.5:35B-3AB

This change sorts models using natural order so numeric parts are
compared numerically.
2026-03-12 07:49:34 -07:00
Benson Wong 24efdb76b1 config: add macro support for name and description fields (#578)
Extend macro substitution to the name and description fields of
ModelConfig, matching the behavior already present for cmd, proxy,
checkEndpoint, and filters.

- substitute global/model macros (including MODEL_ID) in name and
description
- substitute PORT macro in name and description when allocated
- validate no unknown macros remain in name and description after
substitution
- add tests for macro substitution, MODEL_ID, and unknown macro error
2026-03-10 08:27:05 -07:00
Benson Wong cc77139ff8 proxy,proxy/config: add global TTL feature (#554)
Add a new configuration parameter globalTTL that all models will
inherit. The default value is 0 which matches the currently
functionality to never automatically unload a model.

The model.ttl's default has changed to -1, which means use the global
TTL value. Any model.ttl >=0 is now value with 0 meaning never unload.
This allows a model to override a globalTTL > 0 and be configured to
never unload.

Fixes #459
Closes #512
2026-03-01 21:02:12 -08:00
Benson Wong 390a35bf93 ui-svelte: add copy button to markdown code blocks (#537)
Add a copy-to-clipboard button that appears on hover for each code block
rendered in the chat interface assistant messages.

- Svelte action `codeBlockCopy` injects a button into every `<pre>`
element
- MutationObserver reattaches buttons as streaming content arrives
- Button shows a check icon for 2 seconds after a successful copy
- Uses clipboard API with execCommand fallback for non-secure contexts
- CSS hides button by default and reveals it on pre:hover

https://claude.ai/code/session_01PTA5ao5YQuFAS6a9juLeZW

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-03-01 09:48:56 -08:00
pdscomp 181f71ca11 .github,docker: add cuda13 architecture support (#551)
Add `cuda13` as a supported build architecture, targeting the
`ghcr.io/ggml-org/llama.cpp:server-cuda13` upstream base image.

The `server-cuda13` image ships with CUDA 13 libraries, providing
improved performance on recent NVIDIA hardware compared to the existing
`server-cuda` (CUDA 12) image. Users with newer GPUs (e.g., RTX
50-series) benefit from reduced model load latency and higher token
throughput.

- Add `cuda13` to the allowed architectures list in
`docker/build-container.sh`
- Add `cuda13` to the CI matrix in `.github/workflows/containers.yml` so
the container is built and pushed automatically
2026-03-01 09:37:08 -08:00
Benson Wong 49546e2cf2 ui: fix text size svg 2026-02-27 23:47:52 -08:00
Benson Wong 2c078964f4 Update README with additional images
Added new images for model loading and real-time log streaming sections.
2026-02-27 23:45:40 -08:00
Benson Wong 175bb36fb1 Revise README description for clarity and detail
Updated description to clarify compatibility and usage.
2026-02-27 23:42:40 -08:00
Benson Wong aedb640471 Enhance web UI section in README
Updated README to enhance the description of the web interface and added details about features like token metrics, request inspection, model management, and real-time log streaming.
2026-02-27 23:40:31 -08:00
Benson Wong 2f377f6dc6 ui: add OGG audio format support to transcription playground (#544) 2026-02-26 19:48:19 -08:00
Benson Wong 64e4c79fc3 ui: add Rerank tab to playground (#536)
Add a new Rerank tab to the playground that lets users test /v1/rerank
endpoints. Supports a visual table editor and a JSON editor mode that
stay in sync when toggling between them.

- add rerankApi.ts with typed wrapper for /v1/rerank
- add RerankInterface.svelte with query input, sortable document table,
color-coded scores, auto-add row, cancel/clear, and token usage
- add rerankLoading store to playgroundActivity derived store
- register Rerank tab in Playground.svelte

Updates #481
2026-02-21 21:59:14 -08:00
Benson Wong 19fb5f35e9 proxy: implement setParamsByID filter (#535)
Add setParamsByID filter that applies different request parameters based
on the requested model ID, enabling per-alias behaviour for a single
loaded model.

- add SetParamsByID field to Filters struct and SanitizedSetParamsByID
method
- substitute ${MODEL_ID} and other macros in setParamsByID keys and
values
- validate no unknown macros remain in keys or values after substitution
- apply setParamsByID in proxyInferenceHandler after setParams (can
override it)
- update config-schema.json with setParamsByID definition
- update UI to show aliases and make them selectable in the Playground

closes #534
2026-02-19 22:21:10 -08:00
Benson Wong b45102bde8 ui: smart auto-scroll in LogPanel (#530)
Pause auto-scroll when the user scrolls up to review logs, and resume
when they scroll back to the bottom.

- add `userScrolledUp` state variable
- add `handleScroll` to detect scroll position with 40px threshold
- guard the auto-scroll effect with `!userScrolledUp`

closes #529
2026-02-18 19:47:37 -08:00
Brian Mendonca 1688bdd1e9 proxy, ui: add pending requests count to the main dashboard (#516)
add a real time counter of pending (inflight) requests to the UI.
2026-02-16 09:41:15 -08:00
Benson Wong d33d51fa75 .coderabbit.yaml,AGENTS.md: small tweaks 2026-02-15 21:31:30 -08:00
Benson Wong e3bf065574 ui: persist playground state across route navigation (#525)
- Keep Playground component mounted when navigating away, preserving
streaming/generating state
- Add animated gradient effect on Playground nav link when activity is
in progress
2026-02-15 21:30:52 -08:00
Benson Wong 3e52144058 ui-svelte: incremental rendering of chat messages in the Playground (#520)
add incremental rendering to Playground > Chat
2026-02-15 11:00:44 -08:00
Benson Wong d5e52d7d00 build: disable provenance attestations in container builds (#523)
## Summary
- Add `--provenance=false` to docker build commands in
`build-container.sh`
- BuildKit attestation manifests are stored as untagged images in GHCR,
and the `delete-untagged-containers` cleanup job deletes them, breaking
the manifest list and causing `manifest unknown` errors on pull
- ref: https://github.com/actions/delete-package-versions/issues/162
2026-02-14 10:23:08 -08:00
Benson Wong 17e5263a76 .github/workflows: fix expired token in publishing images (#522)
Fixes: #517
2026-02-14 10:06:05 -08:00
Benson Wong 8d6d949ec3 proxy: support timings for /infill from llama-server (#510)
fixes: #463
2026-02-07 17:16:27 -08:00
Benson Wong b5fde8eb6d proxy,ui-svelte: add request/response capturing (#508)
Add saving request and response headers and bodies that go through
llama-swap in memory.

- captureBuffer added to configuration. Captures are enabled by default.
- 5MB of memory is allocated for req/response captures in a ring buffer.
Setting captureBuffer to 0 will disable captures.
- UI elements to view captured data added to Activity page. Includes
some
QOL features like json formatting and recombining SSE chat streams
- capture saving is done at the byte level and has minimal impact on
llama-swap performance

Fixes #464 
Ref #503
2026-02-07 15:40:01 -08:00
Nuno 7eef5defb8 docs: add stable-diffusion.cpp references (#506)
Signed-off-by: rare-magma <rare-magma@posteo.eu>
2026-02-04 20:20:39 -08:00
Benson Wong bc01e6f539 build: add stable-diffusion server to musa and vulkan container images (#504)
Add sd-server from stable-diffusion.cpp docker image for 
vulkan and musa containers.

closes #450
2026-02-01 16:17:26 -08:00
Benson Wong 0462e3dc3f Reorganize UI controls and improve form interactions (#500)
Reorganizes control placement in the playground interfaces and
improves form interactions for better UX, particularly on mobile
devices.

## Key Changes

- **AudioInterface & ImageInterface**: Moved "Clear" buttons from the
top control bar into the action button group below the form inputs for
better visual hierarchy and logical grouping
- **ImageInterface**: 
- Added prompt clearing to the `clearImage()` function so the input
field is reset when clearing generated images
- Updated Clear button disabled state to also check if prompt is empty,
allowing users to clear an empty prompt
- Added responsive flex styling (`flex-1 md:flex-none`) to the Clear
button for better mobile layout
- **ExpandableTextarea**: 
- Imported `untrack` from Svelte to properly handle reactive
dependencies
- Wrapped `expandedValue.length` in `untrack()` to prevent unnecessary
reactivity when setting cursor position
- Improved button visibility on mobile by changing opacity from
`opacity-0` to `opacity-60` with `md:opacity-0` breakpoint, making the
expand button more discoverable on touch devices

## Implementation Details
The `untrack()` usage in ExpandableTextarea ensures that reading the
text length doesn't create a reactive dependency, preventing potential
infinite loops while still allowing the effect to run when `isExpanded`
changes.
2026-02-01 15:18:22 -08:00
Benson Wong 7b20fc011b Add path filters to CI workflows and create UI test workflow (#501)
* .github/workflows: add UI tests and path-filter Go CI

Add ui-tests.yml workflow to run svelte type checking and vitest
on push/PR to main when ui-svelte/ files change.

- Add path filters to go-ci.yml and go-ci-windows.yml to skip
  Go tests when only non-backend files change
- Filter on **/*.go, go.mod, go.sum, and Makefile

https://claude.ai/code/session_01E6acq54D8JjuE7pczxPGT7

* ui-svelte: remove unused declarations in SpeechInterface

Remove unused `generatedText` state and `clearAudio` function
that caused svelte-check errors.

https://claude.ai/code/session_01E6acq54D8JjuE7pczxPGT7

* .github/workflows: update Node.js to v24

Node 23 is end-of-life; bump to 24 in ui-tests.yml and release.yml.

https://claude.ai/code/session_01E6acq54D8JjuE7pczxPGT7

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-01 15:11:49 -08:00
Benson Wong 20738f3623 proxy,ui-svelte: replace old UI with svelte+playground
Replace the legacy React UI with the new Svelte-based one. Introduce a Playground in the UI to quickly test out text, image, text to speech and speech to text models behind llama-swap. 

Key Changes

New Svelte UI (ui-svelte/)

  - Multi-tab Playground with Chat, Image Generation, Audio Transcription, and Speech interfaces
  - Chat: message editing/regeneration, markdown rendering with LaTeX math support, image attachments, code syntax highlighting
  - Image: size selector, download/fullscreen viewing
  - Audio: transcription with peer support
  - Speech: voice caching with manual refresh, download button
  - Responsive mobile layout with collapsible navigation
  - XSS fixes and accessibility improvements

Proxy Improvements

  - Add gzip/brotli compression for UI static assets (proxy/ui_compress.go)
  - Add GET /v1/audio/voices?model={model} endpoint for voice listing
  - Add peer support for /v1/audio/transcriptions
2026-01-31 22:49:13 -08:00
Benson Wong cdea7d16bd proxy/config: skip env macros in YAML comment lines (#496)
Fix a bug where ${env.macro_not_exist} in comments would trigger a non-substituted macro error. 

fixes #495
2026-01-30 20:10:29 -08:00
Benson Wong 5de387dbf9 ui: fix node-tar vulnerability 2026-01-28 21:40:18 -08:00
Benson Wong 6f8e7ccb57 .github/workflows: switch release.yml to build ui-svelte 2026-01-28 21:39:10 -08:00
Benson Wong 4384315b44 ui-svelte: add Svelte port of React UI (#487)
Trying out svelte for the UI. The port was done by Claude Code on the iOS app w/ Opus 4.5. 

---

* ui: add Svelte port of React UI

Port the React-based UI to Svelte 5 with the following changes:

- Create new ui-svelte directory with complete Svelte 5 implementation
- Use Svelte stores instead of React contexts for state management
- Implement custom ResizablePanels component to replace react-resizable-panels
- Port all pages: LogViewer, Models, Activity
- Port all components: Header, ConnectionStatus, LogPanel, ModelsPanel, etc.
- Use svelte-spa-router for client-side routing
- Same build output directory (proxy/ui_dist) and base path (/ui/)
- Tailwind CSS 4 with same theme configuration

https://claude.ai/code/session_01F3xXLYsd62gePVSFv7aboP

* ui-svelte: simplify state management

- Remove redundant state syncing pattern in LogPanel and ModelsPanel
- Use store values directly with $ syntax instead of manual subscriptions
- Consolidate duplicate title sync logic in App.svelte
- Use existing syncTitleToDocument() from theme.ts

https://claude.ai/code/session_01F3xXLYsd62gePVSFv7aboP

* ui-svelte: use idiomatic Svelte 5 patterns

- Use $effect for document side effects (theme, title) instead of
  store subscriptions
- Use class: directive for active nav links in Header
- Remove SSR guards (unnecessary for client-only SPA)
- Remove leaked subscription in syncThemeToDocument
- Simplify theme.ts by removing sync functions

https://claude.ai/code/session_01F3xXLYsd62gePVSFv7aboP

* ui-svelte: fix build warnings and improve accessibility

Fix Svelte build warnings and add proper accessibility support
to interactive components.

- add aria-labels to buttons for screen readers
- implement keyboard navigation for resizable separator
- suppress intentional state initialization warnings
- update Makefile to use ui-svelte build directory
- add peer:true to package-lock.json dependencies

* ui-svelte: reorganize navigation and add log view toggle

Make Models the default landing page and add view mode toggle
to the Logs page with persistent state.

- set Models as default route at /
- move Logs to /logs route
- reorder navigation: Models, Activity, Logs
- add view toggle with three modes: Panels, Proxy only, Upstream only
- fix horizontal overflow with width constraints
2026-01-28 21:37:29 -08:00
Benson Wong 6439ab1515 ui: add peer:true in package-lock.json 2026-01-22 08:43:36 -08:00
dependabot[bot] f94226122c build(deps-dev): bump tar from 7.5.3 to 7.5.6 in /ui (#477)
Bumps [tar](https://github.com/isaacs/node-tar) from 7.5.3 to 7.5.6.
- [Release notes](https://github.com/isaacs/node-tar/releases)
- [Changelog](https://github.com/isaacs/node-tar/blob/main/CHANGELOG.md)
- [Commits](https://github.com/isaacs/node-tar/compare/v7.5.3...v7.5.6)

---
updated-dependencies:
- dependency-name: tar
  dependency-version: 7.5.6
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-21 22:55:02 -08:00
Ryan Voots 7493618fdc Add count_tokens api proxying (#476) 2026-01-20 09:34:42 -08:00
Benson Wong 205efd40a1 proxy: extend /running endpoint with additional process data (#474)
Extend the /running endpoint to return more details about running
processes beyond just model and state.

- add cmd field to show the command being executed
- add proxy field to show the proxy URL
- add ttl (UnloadAfter) for automatic unloading configuration
- add name and description for model metadata
- update tests to verify new fields are returned correctly

fixes #471
2026-01-19 17:37:00 -08:00
Benson Wong 14207f8492 ui: npm security update 2026-01-18 21:56:32 -08:00
Benson Wong 4e850c2834 config: refactor macro substitution in configuration (#470)
This commit simplifies substitution of environment variables into the configuration. There was a lot of repetitive code substituting ${env.VAR_NAME} into different fields after the configuration was parsed into a config.Config. This refactor uses a string substitution of env vars into the YAML config before it is fully parsed. This eliminates a lot of logic while maintaining backwards compatibility.
2026-01-18 21:52:34 -08:00
Benson Wong 75fced579e config: support macros in peer apiKey and filters (#469)
* config: support environment variable macros in peer apiKeys

Add ${env.VAR_NAME} substitution for peer apiKey fields, consistent
with existing env macro support for model fields and global apiKeys.

- Add env macro substitution for peers.{name}.apiKey in LoadConfigFromReader
- Add tests for peer apiKey env substitution
- Update config.example.yaml to show env macro usage

* config: support macros in peer apiKey and filters

Extend macro substitution to peer configuration fields:
- peers.{name}.apiKey supports both global macros and env macros
- peers.{name}.filters.stripParams supports both macro types
- peers.{name}.filters.setParams supports both macro types

Also renamed validateMetadataForUnknownMacros to validateNestedForUnknownMacros
for reuse across model metadata and peer filters validation.
2026-01-16 23:10:50 -08:00
Benson Wong b73f367f22 config-schema.json,config.example.yaml: Update examples and schema 2026-01-16 22:43:25 -08:00
Benson Wong 8f2137c72b config: support environment variable macros in apiKeys (#467)
Add substituteEnvMacros support for apiKeys configuration field,
allowing API keys to be loaded from environment variables using
the ${env.VAR_NAME} syntax.

- Apply env macro substitution before validation
- Add tests for env macro substitution in apiKeys
2026-01-16 22:41:14 -08:00
Benson Wong 124007cc98 config: add environment variable macros (#466)
* config: add environment variable macros

Add support for ${env.VAR_NAME} syntax to pull values from system
environment variables during config loading.

- env macros processed before regular macros (allows macros to reference env vars)
- works in cmd, cmdStop, proxy, checkEndpoint, filters.stripParams, metadata
- returns error if env var is not set
- add comprehensive tests

fixes #462

* docs: add env macro example to config.example.yaml
2026-01-16 22:25:20 -08:00
Benson Wong eb5bfff0b0 proxy: unify filtering for local models and peers
This unifies the filtering capabilities for models and peers

- stripParams: removes params in the request
- setParams: sets params in the request

fixes #453
2026-01-15 18:59:43 -08:00
Benson Wong 3edb180c08 ci: free up disk space before ROCm container build (#460) 2026-01-14 22:03:42 -08:00
Benson Wong 66d555e625 Improve container build reliability (#457)
* docker: add .env usage in build-container.sh
* .github,docker: add rocm, improve logging
* .github,CLAUDE.md: fix workflow and update guidelines

Update containers workflow to only push images when triggered
manually or on schedule, not on workflow file changes.

- add push trigger for workflow file changes in containers.yml
- update push condition to skip on regular push events
- update CLAUDE.md commit message guidelines

* docker: remove comma in build-container.sh

* .github,docker: improve container build workflow

Add pagination support for fetching llama.cpp tags and improve debugging.

- add build-container.sh to workflow trigger paths
- implement fetch_llama_tag() with pagination support
- replace .env with local testing instructions
- add DEBUG_ABORT_BUILD flag for testing
2026-01-10 22:14:33 -08:00
Benson Wong 4f863fd9fc CLAUDE.md: tweak instructions 2026-01-09 21:42:06 -08:00
Benson Wong 267c030457 ui: update react-router-dom to 7.12.0 (#456)
Update react-router-dom from 7.6.2 to 7.12.0 to address security vulnerability.

- Updated dependency in package.json
- Regenerated package-lock.json
- Verified build passes successfully
- Confirmed 0 vulnerabilities with npm audit

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-08 16:13:09 -08:00
Benson Wong c19309fe7e CLAUDE.md: small instruction tweaks 2026-01-07 21:34:23 -08:00
Benson Wong 4413881b2d proxy: actually add /v1/responses endpoint (#449)
ref: #448
2026-01-01 13:35:45 -08:00
Benson Wong 8df5e8563b proxy: add /v1/responses and /v1/audio/voices endpoints (#448)
Updates #433
Fixes #442 #226
2026-01-01 12:52:12 -08:00
Benson Wong 7931212d3e proxy: add v1/images/edits API endpoint (#447)
Updates #433
2026-01-01 12:43:06 -08:00
Benson Wong 3dc36032fb proxy: skip very slow tests in -short test mode (#446)
* proxy: skip very slow tests in -short test mode
* CLAUDE.md: update testing instructions
2025-12-31 14:08:56 -08:00
Benson Wong addb98646f proxy: add support for basic authorization (#445)
Fixes #444 where the UI with api keys did not work. The choice to use
http basic authorization is for simple, automatic browser support. No
changes to the UI were necessary. Just use an API key as the password,
no user name is required.
2025-12-31 13:42:35 -08:00
Benson Wong 37d74efc2d proxy: add /v1/images/generations (#443)
Add support for the /v1/images/generations endpoint

Updates #433
Closes #191
2025-12-30 21:04:58 -08:00
Benson Wong 22e098ac8b Add Peer Model Support (#438)
This PR allows a single llama-swap to be the central proxy for models served by other inference servers. The peer servers can be another llama-swap or any API that supports the /v1/* inference endpoint.

Updates: #433, #299
Closes: #296
2025-12-27 20:18:06 -08:00
Benson Wong 9864f9f517 .coderabbit.yaml: disable annoying features 2025-12-23 23:53:06 -08:00
Benson Wong 53b32f3601 proxy: add API key support (#436)
Add configuration support for api keys that are enforced by llama-swap. Keys are stripped before sending them to upstream servers. 

Updates: #433, #50 and #251
2025-12-23 23:39:33 -08:00
Benson Wong 565c44766d config,proxy: add new configuration logToStdout (#432)
The new logToStdout option controls what is logged to stdout. The
default has been changed to just the proxy logs, which contain swap and
http request logs.

There are four supported settings: none, proxy, upstream, both. The
"both" setting is the legacy setting where everything was spewed to
stdout.
2025-12-21 22:23:31 -08:00
Benson Wong e6a9e210ba proxy: fix path bug in /logs/stream/{model_id} (#431)
A {model_id} containing a forward slash trips up gin's path param
parsing. This updates /logs/stream to work like /upstream where the
model_id is built up in parts and searched for in the configuration.

Updates #421
2025-12-21 21:47:14 -08:00
Benson Wong d3f329f924 proxy: Improve logging performance and allow separate log streaming (#421)
Replace container/ring.Ring with a custom circularBuffer that uses a
single contiguous []byte slice. This fixes the original implementation
which created 10,240 ring elements instead of 10KB of storage.

GetHistory is now 139x faster (145μs → 1μs) and uses 117x less memory
(1.2MB → 10KB). Allocations reduced from 2 to 1 per write operation.

Create a LogMonitor per proxy.Process, replacing the usage
of a shared one. The buffer in LogMonitor is lazy allocated on the first
call to Write and freed when the Process is stopped. This reduces
unnecessary memory usage when a model is not active.

The /logs/stream/{model_id} endpoint was added to stream logs from a
specific process.
2025-12-18 21:49:25 -08:00
Benson Wong 98879b38c1 docker: add /app to $PATH (#424)
Make it so llama-server can be called directly instead of with the full
path at /app/llama-server.

Fixes #423
Ref: #233
2025-12-06 22:58:29 -08:00
Benson Wong 7b3b0f5eae move header images around [skip ci] 2025-12-02 19:40:42 -08:00
Benson Wong 021ccceef1 README: update hero image 2025-12-02 19:37:03 -08:00
Benson Wong f03871c50a Update README.md
- add supported anthropic API 
- add example for docker hot reload support
2025-12-02 19:03:01 -08:00
Ryan Steed dc00d17abe docs: add documentation for non-root container images and security considerations (#416)
* docs: add documentation for non-root container images and security considerations
* docs: move container security section to dedicated file and update README links
2025-12-02 08:52:26 -08:00
Benson Wong dea98733c3 proxy: extract metrics for v1/messages (#419) 2025-11-29 23:51:20 -08:00
Benson Wong bccce5fa19 go.mod,ui/package-lock.json: dependency and security updates (#418) 2025-11-29 22:27:22 -08:00
Benson Wong c968da1b73 proxy: add support for anthropic v1/messages api (#417)
* proxy: add support for anthropic v1/messages api
* proxy: restrict loading message to /v1/chat/completions
2025-11-29 22:09:07 -08:00
Ryan Steed a883d68d4f feat: Add support for custom llama.cpp base image and forked llama-swap repositories (#396)
* feat: Add support for custom llama.cpp base image and forked llama-swap repositories

- Introduce BASE_LLAMACPP_IMAGE env var to customize llama.cpp base image
- Introduce LS_REPO env var to customize llama-swap source
- Use GITHUB_REPOSITORY env var to automatically detect forked repos
- Update container tagging to use dynamic repo paths
- Pass build args for BASE_IMAGE and LS_REPO to Containerfile
- Enable flexible release downloads from forked repositories

* chore: quote entire curl options, appease coderabbitai
2025-11-29 20:59:15 -08:00
Ryan Steed b1dec8b735 docker: build both root and non-root container images (#412)
Change the user back to root for containers. Additionally, built a "non-root" labeled container for users who wish to have the additional security of running llama-swap as a lower privileged user.
2025-11-25 10:44:13 -08:00
Nikesh Parajuli 06523d8c1e feat: add platform-specific process attributes support (#411)
Fixes issues on Windows showing new windows for every process llama-swap spawns.
2025-11-24 21:39:56 -08:00
Ryan Steed 86e9b93c37 proxy,ui: add version endpoint and display version info in UI (#395)
- Add /api/version endpoint to ProxyManager that returns build date, commit hash, and version
- Implement SetVersion method to configure version info in ProxyManager
- Add version info fetching to APIProvider and display in ConnectionStatus component
- Include version info in UI context and update dependencies
- Add tests for version endpoint functionality
2025-11-17 10:43:47 -08:00
Ryan Steed 3acace810f proxy: add configurable logging timestamp format (#401)
introduces a new configuration option logTimeFormat that allows customizing the timestamp in log messages using golang's built in time format constants. The default remains no timestamp.
2025-11-16 10:21:59 -08:00
Ryan Steed 554d29e87d feat: enhance model listing to include aliases (#400)
introduce includeAliasesInList as a new configuration setting (default false) that includes aliases in v1/models

Fixes #399
2025-11-15 14:35:26 -08:00
Benson Wong 3567b7df08 Update image in README.md for web UI section 2025-11-08 15:29:37 -08:00
Benson Wong 38738525c9 config.example.yaml: add modeline for schema validation 2025-11-08 15:08:55 -08:00
Benson Wong c0fc858193 Add configuration file JSON schema (#393)
* add json schema for configuration
* add GH action to validate schema
2025-11-08 15:04:14 -08:00
Benson Wong b429349e8a add /ui/ to wol-proxy polling (#388) 2025-11-08 14:16:12 -08:00
Ryan Steed eab2efd7b5 feat: improve llama.cpp base image tag for cpu (#391)
Refactor the container build script to resolve llama.cpp base image for CPU, also tag these builds accordingly.

- For CPU containers, now fetch the latest 'server' tagged llama.cpp image instead of using a generic 'server' tag
- Cleans up the docker build command to use dynamic BASE_TAG variable
- Maintains existing push functionality for built images
2025-11-08 09:56:49 -08:00
Benson Wong 6aedbe121a cmd/wol-proxy: show a loading page for / (#381)
When requesting / wol-proxy will show a loading page that polls /status
every second. When the upstream server is ready the loading page will
refresh causing the actual root page to be displayed
2025-11-03 19:37:06 -08:00
Ryan Steed b24467ab89 fix: update containerfile user/group management commands (#379)
- Replace `addgroup` with `groupadd` for system group creation
- Replace `adduser` with `useradd` for system user creation
- Maintain same functionality while using more standard POSIX commands
2025-11-03 17:17:40 -05:00
Benson Wong 12b69fb718 proxy: recover from panic in Process.statusUpdate (#378)
Process.statusUpdate() panics when it can not write data, usually from a
client disconnect. Since it runs in a goroutine and did not have a
recover() the result was a crash.

ref: https://github.com/mostlygeek/llama-swap/discussions/326#discussioncomment-14856197
2025-11-03 05:30:09 -08:00
Ryan Steed f91a8b2462 refactor: update Containerfile to support non-root user execution and improve security (#368)
Set default container user/group to lower privilege app user 

* refactor: update Containerfile to support non-root user execution and improve security

- Updated LS_VER argument from 89 to 170 to use the latest version
- Added UID/GID arguments with default values of 0 (root) for backward compatibility
- Added USER_HOME environment variable set to /root
- Implemented conditional user/group creation logic that only runs when UID/GID are not 0
- Created necessary directory structure with proper ownership using mkdir and chown commands
- Switched to non-root user execution for improved security posture
- Updated COPY instruction to use --chown flag for proper file ownership

* chore: update containerfile to use non-root user with proper UID/GID

- Changed default UID and GID from 0 (root) to 10001 for security best practices
- Updated USER_HOME from /root to /app to avoid running as root user
2025-10-31 17:01:04 -07:00
Benson Wong a89b803d4a Stream loading state when swapping models (#371)
Swapping models can take a long time and leave a lot of silence while the model is loading. Rather than silently load the model in the background, this PR allows llama-swap to send status updates in the reasoning_content of a streaming chat response.

Fixes: #366
2025-10-29 00:09:39 -07:00
Benson Wong f852689104 proxy: add panic recovery to Process.ProxyRequest (#363)
Switching to use httputil.ReverseProxy in #342 introduced a possible
panic if a client disconnects while streaming the body. Since llama-swap
does not use http.Server the recover() is not automatically there.

- introduce a recover() in Process.ProxyRequest to recover and log the
  event
- add TestProcess_ReverseProxyPanicIsHandled to reproduce and test the
  fix

fixes: #362
2025-10-25 20:40:05 -07:00
Benson Wong e250e71e59 Include metrics from upstream chat requests (#361)
* proxy: refactor metrics recording

- remove metrics_middleware.go as this wrapper is no longer needed. This
  also eliminiates double body parsing for the modelID
- move metrics parsing to be part of MetricsMonitor
- refactor how metrics are recording in ProxyManager
- add MetricsMonitor tests
- improve mem efficiency of processStreamingResponse
- add benchmarks for MetricsMonitor.addMetrics
- proxy: refactor MetricsMonitor to be more safe handling errors
2025-10-25 17:38:18 -07:00
Benson Wong d18dc26d01 cmd/wol-proxy: tweak logs to show what is causing wake ups (#356)
fix the extra wake ups being caused by wol-proxy

* cmd/wol-proxy: tweak logs to show what is causing wake ups
* cmd/wol-proxy: add skip wakeup
* cmd/wol-proxy: replace ticker with SSE connection
* cmd/wol-proxy: increase scanner buffer size
* cmd/wol-proxy: improve failure tracking
2025-10-25 11:04:31 -07:00
Benson Wong 8357714421 ui: fix avg token/sec calculation on models page (#357)
* ui: use percentiles for token stats
* ui: add histogram of metrics
* update vite to remove security warnings

fixes #355
2025-10-23 22:22:24 -07:00
Benson Wong c07179d6e2 cmd/wol-proxy: add wol-proxy (#352)
add a wake-on-lan proxy for llama-swap. When the target llama-swap server is unreachable it will send hold a request, send a WoL packet and proxy the request when llama-swap is available.
2025-10-20 20:55:02 -07:00
Benson Wong 7ff50631e0 Update README for setup instructions clarity [skip ci] 2025-10-19 14:55:23 -07:00
Benson Wong 9fc0431531 Clean up and Documentation (#347) [skip ci]
* cmd,misc: move misc binaries to cmd/
* docs: add docs and move examples/ there
* misc: remove unused misc/assets dir
* docs: add configuration.md
* update README with better structure

Updates: #334
2025-10-19 14:53:13 -07:00
David Wen Riccardi-Zhu 6516532568 Add optional TLS support (#340)
* Add optional TLS support

Introduce HTTPS support with net/http Server.ListenAndServeTLS.

This should enable the option of serving via HTTPS without a reverse
proxy.

Add two flags:
- tls-cert-file (path to the TLS certificate file)
- tls-key-file (path to the TLS private key file)

Both flags must be supplied together; otherwise exit with error.

If both flags are present, call srv.ListenAndServeTLS.
If not, fall back to the existing srv.ListenAndServe (HTTP); no changes
to existing non‑TLS behavior.
2025-10-15 19:29:02 -07:00
David Wen Riccardi-Zhu d58a8b85bf Refactor to use httputil.ReverseProxy (#342)
* Refactor to use httputil.ReverseProxy

Refactor manual HTTP proxying logic in Process.ProxyRequest to use the standard
library's httputil.ReverseProxy.

* Refactor TestProcess_ForceStopWithKill test

Update to handle behavior with httputil.ReverseProxy.

* Fix gin interface conversion panic
2025-10-13 16:47:04 -07:00
Benson Wong caf9e98b1e Fix race conditions in proxy.Process (#349)
- Fix data races found in proxy.Process by go's race detector. 
- Add data race detection to the CI tests. 

Fixes #348
2025-10-13 16:42:49 -07:00
Benson Wong 539278343b ui: tweak vertical space for mobile (#343) 2025-10-10 10:05:36 -07:00
Benson Wong 00b738cd0f Add Macro-In-Macro Support (#337)
Add full macro-in-macro support so any user defined macro can contain another one as long as it was previously declared in the configuration file.

Fixes #336
Supercedes #335
2025-10-06 22:57:15 -07:00
Benson Wong 70930e4e91 proxy: add support for user defined metadata in model configs (#333)
Changes: 

- add Metadata key to ModelConfig
- include metadata in /v1/models under meta.llamaswap key
- add recursive macro substitution into Metadata
- change macros at global and model level to be any scalar type

Note: 

This is the first mostly AI generated change to llama-swap. See #333 for notes about the workflow and approach to AI going forward.
2025-10-04 19:56:41 -07:00
Benson Wong 1f6179110c proxy/config: add model level macros (#330)
* proxy/config: add model level macros

Add macros to model configuration. Model macros override macros that are
defined at the global configuration level. They follow the same naming
and value rules as the global macros.

* proxy/config: fix bug with macro reserved name checking

The PORT reserved name was not properly checked

* proxy/config: add tests around model.filters.stripParams

- add check that model.filters.stripParams has no invalid macros
- renamed strip_params to stripParams for camel case consistency
- add legacy code compatibility so  model.filters.strip_params continues to work

* proxy/config: add duplicate removal to model.filters.stripParams

* clean up some doc nits
2025-09-28 23:32:52 -07:00
Benson Wong 216c40b951 proxy/config: create config package and migrate configuration (#329)
* proxy/config: create config package and migrate configuration

The configuration is become more complex as llama-swap adds more
advanced features. This commit moves config to its own package so it can
be developed independently of the proxy package.

Additionally, enforcing a public API for a configuration will allow
downstream usage to be more decoupled.
2025-09-28 16:50:06 -07:00
179 changed files with 26169 additions and 7587 deletions
+8 -1
View File
@@ -4,12 +4,19 @@ early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
high_level_summary: false
poem: false
review_status: true
collapse_walkthrough: false
sequence_diagrams: false
finishing_touches:
docstrings:
enabled: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true
issue_enrichment:
planning:
enabled: false
+56
View File
@@ -0,0 +1,56 @@
name: Validate JSON Schema
on:
pull_request:
paths:
- "config-schema.json"
- "config.example.yaml"
- ".github/workflows/config-schema.yml"
push:
branches:
- main
paths:
- "config-schema.json"
- "config.example.yaml"
- ".github/workflows/config-schema.yml"
workflow_dispatch:
jobs:
validate-schema:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Validate JSON Schema
run: |
# Check if the file is valid JSON
if ! jq empty config-schema.json 2>/dev/null; then
echo "Error: config-schema.json is not valid JSON"
exit 1
fi
# Validate that it's a valid JSON Schema
# Check for required $schema field
if ! jq -e '."$schema"' config-schema.json > /dev/null; then
echo "Warning: config-schema.json should have a \$schema field"
fi
# Check that it has either properties or definitions
if ! jq -e '.properties or .definitions or ."$defs"' config-schema.json > /dev/null; then
echo "Warning: JSON Schema should contain properties, definitions, or \$defs"
fi
echo "✓ config-schema.json is valid"
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Install check-jsonschema
run: pip install check-jsonschema
- name: Validate config.example.yaml against schema
run: check-jsonschema --schemafile config-schema.json config.example.yaml
+29 -2
View File
@@ -10,17 +10,44 @@ on:
# Allows manual triggering of the workflow
workflow_dispatch:
# Run on workflow file changes (without pushing)
push:
paths:
- '.github/workflows/containers.yml'
- 'docker/build-container.sh'
- 'docker/*.Containerfile'
# grant permissions on GITHUB_TOKEN to publish packages
# ref: https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#publishing-a-package-using-an-action
permissions:
contents: read
packages: write
id-token: write
jobs:
build-and-push:
runs-on: ubuntu-latest
strategy:
matrix:
platform: [intel, cuda, vulkan, cpu, musa]
platform: [intel, cuda, cuda13, vulkan, cpu, musa, rocm]
fail-fast: false
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Free up disk space
if: matrix.platform == 'rocm'
run: |
echo "Before cleanup:"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
echo "After cleanup:"
df -h
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
@@ -31,7 +58,7 @@ jobs:
- name: Run build-container
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: ./docker/build-container.sh ${{ matrix.platform }} true
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
# see: https://github.com/actions/delete-package-versions/issues/74
+18 -2
View File
@@ -3,9 +3,25 @@ name: Windows CI
on:
push:
branches: [ "main" ]
# only run when backend source changes
# cmd/ is excluded because it contains utilities without tests
paths:
- '**/*.go'
- '!cmd/**'
- 'go.mod'
- 'go.sum'
- 'Makefile'
- '.github/workflows/go-ci-windows.yml'
pull_request:
branches: [ "main" ]
paths:
- '**/*.go'
- '!cmd/**'
- 'go.mod'
- 'go.sum'
- 'Makefile'
- '.github/workflows/go-ci-windows.yml'
# Allows manual triggering of the workflow
workflow_dispatch:
@@ -28,7 +44,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
@@ -43,7 +59,7 @@ jobs:
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
- name: Test all
shell: bash
+50 -35
View File
@@ -2,53 +2,68 @@ name: Linux CI
on:
push:
branches: [ "main" ]
branches: ["main"]
# only run when backend source changes
# cmd/ is excluded because it contains utilities without tests
paths:
- "**/*.go"
- "!cmd/**"
- "go.mod"
- "go.sum"
- "Makefile"
- ".github/workflows/go-ci.yml"
pull_request:
branches: [ "main" ]
branches: ["main"]
paths:
- "**/*.go"
- "!cmd/**"
- "go.mod"
- "go.sum"
- "Makefile"
- ".github/workflows/go-ci.yml"
# Allows manual triggering of the workflow
workflow_dispatch:
jobs:
run-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version-file: go.mod
# Only run in this linux based runner
- name: Check Formatting
run: |
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
gofmt -l . | grep -v 'event/.*_test.go'
exit 1
fi
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
# Only run in this linux based runner
- name: Check Formatting
run: |
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
gofmt -l . | grep -v 'event/.*_test.go'
exit 1
fi
# cache simple-responder to save the build time
- name: Restore Simple Responder
id: restore-simple-responder
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
run: make simple-responder
# necessary for testing proxy/Process swapping
- name: Create simple-responder
run: make simple-responder
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
- name: Save Simple Responder
# nothing new to save ... skip this step
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
id: save-simple-responder
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
- name: Test all
run: make test-all
- name: Test all
run: make test-all
+11 -16
View File
@@ -3,13 +3,13 @@ name: goreleaser
on:
push:
tags:
- '*'
- "*"
# Allows manual triggering of the workflow
workflow_dispatch:
inputs:
tag:
description: 'Tag version to release (e.g. v144)'
description: "Tag version to release (e.g. v144)"
required: true
permissions:
@@ -19,35 +19,30 @@ jobs:
goreleaser:
runs-on: ubuntu-latest
steps:
-
name: Checkout
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.tag || github.ref }}
-
name: Set up Go
- name: Set up Go
uses: actions/setup-go@v5
-
name: Set up Node.js
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '23'
-
name: Install dependencies and build UI
node-version: "24"
- name: Install dependencies and build UI
run: |
cd ui
cd ui-svelte
npm ci
npm run build
-
name: Run GoReleaser
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
# either 'goreleaser' (default) or 'goreleaser-pro'
distribution: goreleaser
# 'latest', 'nightly', or a semver
version: '~> v2'
version: "~> v2"
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -76,4 +71,4 @@ jobs:
"release": {
"tag_name": "${{ steps.tag.outputs.tag }}"
}
}
}
+42
View File
@@ -0,0 +1,42 @@
name: UI Tests
on:
push:
branches: [ "main" ]
paths:
- 'ui-svelte/**'
- '.github/workflows/ui-tests.yml'
pull_request:
branches: [ "main" ]
paths:
- 'ui-svelte/**'
- '.github/workflows/ui-tests.yml'
workflow_dispatch:
jobs:
run-tests:
runs-on: ubuntu-latest
defaults:
run:
working-directory: ui-svelte
steps:
- uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '24'
cache: 'npm'
cache-dependency-path: ui-svelte/package-lock.json
- name: Install dependencies
run: npm ci
- name: Type check
run: npm run check
- name: Run tests
run: npm test
+136
View File
@@ -0,0 +1,136 @@
name: Build Unified Docker Image
on:
schedule:
- cron: "37 5 * * *"
workflow_dispatch:
inputs:
llama_cpp_ref:
description: "llama.cpp commit hash, tag, or branch"
required: false
default: "master"
whisper_ref:
description: "whisper.cpp commit hash, tag, or branch"
required: false
default: "master"
sd_ref:
description: "stable-diffusion.cpp commit hash, tag, or branch"
required: false
default: "master"
ik_llama_ref:
description: "ik_llama.cpp commit hash, tag, or branch (CUDA only)"
required: false
default: "main"
llama_swap_version:
description: "llama-swap version (e.g. v198, latest, main)"
required: false
default: "main"
build_cuda:
description: "Build CUDA image"
type: boolean
required: false
default: true
build_vulkan:
description: "Build Vulkan image"
type: boolean
required: false
default: true
push_to_ghcr:
description: "Push images to ghcr.io"
type: boolean
required: false
default: true
permissions:
contents: read
packages: write
jobs:
setup:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- id: set-matrix
run: |
backends=()
# schedule uses defaults (build both); workflow_dispatch respects inputs
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_cuda }}" == "true" ]]; then
backends+=("cuda")
fi
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_vulkan }}" == "true" ]]; then
backends+=("vulkan")
fi
matrix=$(printf '%s\n' "${backends[@]}" | jq -R . | jq -sc .)
echo "matrix=$matrix" >> $GITHUB_OUTPUT
build:
needs: setup
if: ${{ needs.setup.outputs.matrix != '[]' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Free up disk space
run: |
echo "Before cleanup:"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
echo "After cleanup:"
df -h
# On GitHub Actions runners, create a fresh builder.
# When running locally under act, skip this and reuse the existing
# llama-swap-builder (which has ccache warm) to avoid exhausting disk.
- name: Set up Docker Buildx
if: ${{ !env.ACT }}
uses: docker/setup-buildx-action@v3
- name: Log in to GitHub Container Registry
if: ${{ !env.ACT }}
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build unified Docker image (${{ matrix.backend }})
env:
LLAMA_REF: ${{ inputs.llama_cpp_ref || 'master' }}
WHISPER_REF: ${{ inputs.whisper_ref || 'master' }}
SD_REF: ${{ inputs.sd_ref || 'master' }}
IK_LLAMA_REF: ${{ inputs.ik_llama_ref || 'main' }}
LS_VERSION: ${{ inputs.llama_swap_version || 'main' }}
DOCKER_IMAGE_TAG: ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}
# When running under act, use the local builder that has warm ccache.
# On GitHub Actions, BUILDX_BUILDER is unset so docker uses the builder
# created by setup-buildx-action above.
BUILDX_BUILDER: ${{ env.ACT == 'true' && 'llama-swap-builder' || '' }}
run: |
chmod +x docker/unified/build-image.sh
docker/unified/build-image.sh --${{ matrix.backend }}
- name: Push to GitHub Container Registry
if: ${{ !env.ACT && inputs.push_to_ghcr == true }}
run: |
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
DATE_TAG=$(date -u +%Y-%m-%d)
docker push "${BASE_TAG}"
docker tag "${BASE_TAG}" "${BASE_TAG}-${DATE_TAG}"
docker push "${BASE_TAG}-${DATE_TAG}"
ROOTLESS_TAG="${BASE_TAG}-rootless"
docker push "${ROOTLESS_TAG}"
docker tag "${ROOTLESS_TAG}" "${ROOTLESS_TAG}-${DATE_TAG}"
docker push "${ROOTLESS_TAG}-${DATE_TAG}"
+51
View File
@@ -0,0 +1,51 @@
## Project Description:
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
## Tech stack
- golang
- typescript, vite and svelt5 for UI (located in ui/)
## Workflow Tasks
- when summarizing changes only include details that require further action
- just say "Done." when there is no further action
- use the github CLI `gh` to create pull requests and work with github
- Rules for creating pull requests:
- keep them short and focused on changes.
- never include a test plan
- write the summary using the same style rules as commit message
## Testing
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests.
### Commit message example format:
```
proxy: add new feature
Add new feature that implements functionality X and Y.
- key change 1
- key change 2
- key change 3
fixes #123
```
## Code Reviews
- use three levels High, Medium, Low severity
- label each discovered issue with a label like H1, M2, L3 respectively
- High severity are must fix issues (security, race conditions, critical bugs)
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
- Low severity are nice to have changes and nits
- Include a suggestion with each discovered item
- Limit your code review to three items with the highest priority first
- Double check your discovered items and recommended remediations
+1
View File
@@ -0,0 +1 @@
@AGENTS.md
+29 -11
View File
@@ -23,18 +23,24 @@ proxy/ui_dist/placeholder.txt:
mkdir -p proxy/ui_dist
touch $@
test: proxy/ui_dist/placeholder.txt
go test -short -v -count=1 ./proxy
# use cached test results while developing
test-dev: proxy/ui_dist/placeholder.txt
go test -short ./proxy/...
staticcheck ./proxy/... || true
test: proxy/ui_dist/placeholder.txt
go test -short -count=1 ./proxy/...
# for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt
go test -v -count=1 ./proxy
go test -race -count=1 ./proxy/...
ui/node_modules:
cd ui && npm install
cd ui-svelte && npm install
# build react UI
ui: ui/node_modules
cd ui && npm run build
cd ui-svelte && npm run build
# Build OSX binary
mac: ui
@@ -42,9 +48,14 @@ mac: ui
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
# Build Linux binary
linux: ui
@echo "Building Linux binary..."
linux: linux-arm64 linux-amd64
linux-amd64: ui
@echo "Building Linux AMD64 binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
linux-arm64: ui
@echo "Building Linux ARM64 binary..."
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
# Build Windows binary
@@ -55,12 +66,12 @@ windows: ui
# for testing proxy.Process
simple-responder:
@echo "Building simple responder"
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 cmd/simple-responder/simple-responder.go
simple-responder-windows:
@echo "Building simple responder for windows"
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe cmd/simple-responder/simple-responder.go
# Ensure build directory exists
$(BUILD_DIR):
@@ -80,5 +91,12 @@ release:
echo "tagging new version: $$new_tag"; \
git tag "$$new_tag";
GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
wol-proxy: $(BUILD_DIR)
@echo "Building wol-proxy"
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
# Phony targets
.PHONY: all clean ui mac linux windows simple-responder
.PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
.PHONE: linux linux-arm64 linux-amd64
+188 -162
View File
@@ -1,77 +1,204 @@
![llama-swap header image](header2.png)
![llama-swap header image](docs/assets/hero3.webp)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary, a provided docker images or Homebrew.
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
## Features:
- ✅ Easy to deploy: single binary with no dependencies
- ✅ Easy to config: single yaml file
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
- ✅ On-demand model switching
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
- future proof, upgrade your inference servers at any time.
- ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/responses`
- `v1/embeddings`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- ✅ llama-server (llama.cpp) supported endpoints:
- `v1/audio/voices`
- `v1/images/generations`
- `v1/images/edits`
- ✅ Anthropic API supported endpoints:
- `v1/messages`
- `v1/messages/count_tokens`
- ✅ llama-server (llama.cpp) supported endpoints
- `v1/rerank`, `v1/reranking`, `/rerank`
- `/infill` - for code infilling
- `/completion` - for completion endpoint
-llama-swap custom API endpoints
-SDAPI via [stable-diffusion.cpp's server](https://github.com/leejet/stable-diffusion.cpp/tree/master/examples/server)
- `/sdapi/v1/txt2img`
- `/sdapi/v1/img2img`
- `/sdapi/v1/loras` - requires `model` in request body to fetch the correct loras
- ✅ llama-swap API
- `/ui` - web UI
- `/log` - remote log monitoring
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/log` - remote log monitoring
- `/health` - just returns "OK"
-Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
-Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Reliable Docker and Podman support using `cmd` and `cmdStop` together
- ✅ Full control over server settings per model
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
-API Key support - define keys to restrict access to API endpoints
-Customizable
- Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643))
- Automatic unloading of models after timeout by setting a `ttl`
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
### Web UI
llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
View detailed token metrics:
<img width="1111" height="515" alt="image" src="https://github.com/user-attachments/assets/64bfb280-d7a3-4126-971a-a128fd40410c" />
Inspect request and responses:
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
Manually load and unload models:
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
Real time log streaming:
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
## Installation
llama-swap can be installed in multiple ways
1. Docker
2. Homebrew (OSX and Linux)
3. WinGet
4. From release binaries
5. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
```shell
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
# run with a custom configuration and models directory
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
# configuration hot reload supported with a
# directory volume mount
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
-v /path/to/config:/config \
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
```
<details>
<summary>
more examples
</summary>
```shell
# pull latest images per platform
docker pull ghcr.io/mostlygeek/llama-swap:cpu
docker pull ghcr.io/mostlygeek/llama-swap:cuda
docker pull ghcr.io/mostlygeek/llama-swap:vulkan
docker pull ghcr.io/mostlygeek/llama-swap:intel
docker pull ghcr.io/mostlygeek/llama-swap:musa
# tagged llama-swap, platform and llama-server version images
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
# non-root cuda
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
```
</details>
### Homebrew Install (macOS/Linux)
```shell
brew tap mostlygeek/llama-swap
brew install llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
### WinGet Install (Windows)
> [!NOTE]
> WinGet is maintained by community contributor [Dvd-Znf](https://github.com/Dvd-Znf) ([#327](https://github.com/mostlygeek/llama-swap/issues/327)). It is not an official part of llama-swap.
```shell
# install
C:\> winget install llama-swap
# upgrade
C:\> winget upgrade llama-swap
```
### Pre-built Binaries
Binaries are available on the [release](https://github.com/mostlygeek/llama-swap/releases) page for Linux, Mac, Windows and FreeBSD.
### Building from source
1. Building requires Go and Node.js (for UI).
1. `git clone https://github.com/mostlygeek/llama-swap.git`
1. `make clean all`
1. look in the `build/` subdirectory for the llama-swap binary
## Configuration
```yaml
# minimum viable config.yaml
models:
model1:
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
```
That's all you need to get started:
1. `models` - holds all model configurations
2. `model1` - the ID used in API calls
3. `cmd` - the command to run to start the server.
4. `${PORT}` - an automatically assigned port number
Almost all configuration settings are optional and can be added one step at a time:
- Advanced features
- `matrix` to run concurrent models with a custom swap logic DSL
- `hooks` to run things on startup
- `macros` reusable snippets
- Model customization
- `ttl` to automatically unload models
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
- `env` to pass custom environment variables to inference servers
- `cmdStop` gracefully stop Docker/Podman containers
- `useModelName` to override model names sent to upstream servers
- `${PORT}` automatic port variables for dynamic port assignment
- `filters` rewrite parts of requests before sending to the upstream server
See the [configuration documentation](docs/configuration.md) for all options.
## How does llama-swap work?
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## config.yaml
llama-swap is managed entirely through a yaml configuration file.
It can be very minimal to start:
```yaml
models:
"qwen2.5":
cmd: |
/path/to/llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port ${PORT}
```
However, there are many more capabilities that llama-swap supports:
- `groups` to run multiple models at once
- `ttl` to automatically unload models
- `macros` for reusable snippets
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
- `env` to pass custom environment variables to inference servers
- `cmdStop` for to gracefully stop Docker/Podman containers
- `useModelName` to override model names sent to upstream servers
- `healthCheckTimeout` to control model startup wait times
- `${PORT}` automatic port variables for dynamic port assignment
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, using a `matrix` allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## Reverse Proxy Configuration (nginx)
@@ -97,129 +224,28 @@ location /v1/chat/completions {
As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. However, explicitly disabling `proxy_buffering` at your reverse proxy is still recommended for reliable streaming behavior.
## Web UI
## Monitoring Logs on the CLI
llama-swap includes a real time web interface for monitoring logs and models:
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/adef4a8e-de0b-49db-885a-8f6dedae6799" />
The Activity Page shows recent requests:
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
## Installation
llama-swap can be installed in multiple ways
1. Docker
2. Homebrew (OSX and Linux)
3. From release binaries
4. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker images with llama-swap and llama-server are built nightly.
```shell
# use CPU inference comes with the example config above
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
# qwen2.5 0.5B
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
# SmolLM2 135M
$ curl -s http://localhost:9292/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
jq -r '.choices[0].message.content'
```
<details>
<summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
They include:
- `ghcr.io/mostlygeek/llama-swap:cpu`
- `ghcr.io/mostlygeek/llama-swap:cuda`
- `ghcr.io/mostlygeek/llama-swap:intel`
- `ghcr.io/mostlygeek/llama-swap:vulkan`
- ROCm disabled until fixed in llama.cpp container
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
```shell
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
```
</details>
### Homebrew Install (macOS/Linux)
The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
```shell
# Set up tap and install formula
brew tap mostlygeek/llama-swap
brew install llama-swap
# Run llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
Available flags:
- `--config`: Path to the configuration file (default: `config.yaml`).
- `--listen`: Address and port to listen on (default: `:8080`).
- `--version`: Show version information and exit.
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
### Building from source
1. Build requires golang and nodejs for the user interface.
1. `git clone https://github.com/mostlygeek/llama-swap.git`
1. `make clean all`
1. Binaries will be in `build/` subdirectory
## Monitoring Logs
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
CLI access is also supported:
```shell
```sh
# sends up to the last 10KB of logs
curl http://host/logs'
$ curl http://host/logs
# streams combined logs
curl -Ns 'http://host/logs/stream'
curl -Ns http://host/logs/stream
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# stream llama-swap's proxy status logs
curl -Ns http://host/logs/stream/proxy
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream logs from upstream processes that llama-swap loads
curl -Ns http://host/logs/stream/upstream
# stream logs only from a specific model
curl -Ns http://host/logs/stream/{model_id}
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
# skips history and just streams new log entries
# appending ?no-history will disable sending buffered history first
curl -Ns 'http://host/logs/stream?no-history'
```
@@ -227,11 +253,11 @@ curl -Ns 'http://host/logs/stream?no-history'
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals for proper shutdown.
## Star History
> [!NOTE]
> ⭐️ Star this project to help others discover it!
> ⭐️ Star this project to help others discover it!
[![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date)
@@ -0,0 +1,85 @@
# Replace ring.Ring with Efficient Circular Byte Buffer
## Overview
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
## Current Issues
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
4. Linked list structure has poor cache locality and pointer overhead
## Design Requirements
### New CircularBuffer Type
Create a simple circular byte buffer with:
- Single pre-allocated `[]byte` of fixed capacity (10KB)
- `head` and `size` integers to track write position and data length
- No per-write allocations
### API Requirements
The new buffer must support:
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
### Implementation Details
```go
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
}
```
**Write logic:**
- If `len(p) >= capacity`: just keep the last `capacity` bytes
- Otherwise: write bytes at `head`, wrapping around if needed
- Update `head` and `size` accordingly
- Data is copied into the internal buffer (not stored by reference)
**GetHistory logic:**
- Calculate start position: `(head - size + cap) % cap`
- If not wrapped: single slice copy
- If wrapped: two copies (end of buffer + beginning)
- Returns a **new slice** (copy), not a view into internal buffer
### Immutability Guarantees (must preserve)
Per existing tests:
1. Modifying input `[]byte` after `Write()` must not affect stored data
2. `GetHistory()` returns independent copy - modifications don't affect buffer
## Files to Modify
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
## Testing Plan
Existing tests in `logMonitor_test.go` should continue to pass:
- `TestLogMonitor` - Basic write/read and subscriber notification
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
- `TestWrite_LogTimeFormat` - Timestamp formatting
Add new tests:
- Test buffer wrap-around behavior
- Test large writes that exceed buffer capacity
- Test exact capacity boundary conditions
## Checklist
- [ ] Create `circularBuffer` struct in `logMonitor.go`
- [ ] Implement `Write()` method for circular buffer
- [ ] Implement `GetHistory()` method for circular buffer
- [ ] Update `LogMonitor` struct to use new buffer
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
- [ ] Update `LogMonitor.Write()` to use new buffer
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
- [ ] Remove `"container/ring"` import
- [ ] Run `make test-dev` to verify existing tests pass
- [ ] Add wrap-around test case
- [ ] Run `make test-all` for final validation
+183
View File
@@ -0,0 +1,183 @@
# Improve Testability (#655)
## Current Pain Points
1. **Tests bypass config loading** - ~80% of tests build `config.Config` structs directly, skipping YAML parsing, env var substitution, macro expansion, and `${PORT}` assignment. Config bugs in those paths go untested.
2. **simple-responder is everywhere** - Every proxy/routing test launches a real subprocess, waits for health checks (~healthCheckTimeout: 15), and manages process lifecycle just to test HTTP routing. Most of that overhead is wasted.
3. **Port counter is fragile** - A global `nextTestPort` counter starting at 12000 with a mutex. Parallel tests or leftover processes can collide.
## Stages
### Stage 1: YAML-based test config helper
**Goal:** Tests go through the real `LoadConfigFromReader` path instead of hand-building structs.
**Effort:** Low | **Impact:** Config bugs caught earlier | **Risk:** None
Create a test helper in `proxy/helpers_test.go`:
```go
// testConfigFromYAML substitutes simple-responder paths and loads through
// the real config pipeline (env vars, macros, port assignment, etc.)
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
t.Helper()
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
require.NoError(t, err)
return cfg
}
```
Tests would then look like:
```go
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
model2:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model2
`)
proxy := New(config)
// ... same assertions
}
```
**Why this stage first:** Zero production code changes. Pure test-side refactoring. Can be done incrementally - migrate tests one at a time. Each migrated test now validates the full config pipeline.
**Scope:** ~20-30 tests in `proxymanager_test.go`, `processgroup_test.go`, `peerproxy_test.go`.
### Stage 2: Injected test handler (eliminate simple-responder for routing tests)
**Goal:** Replace simple-responder subprocess launches with an injected `http.Handler` for tests that don't specifically test process lifecycle.
**Effort:** Medium | **Impact:** 10-100x faster routing tests | **Risk:** Low (additive, no existing code broken)
Add a `testHandler http.Handler` field to `Process`. When set, `ProxyRequest` delegates directly to this handler instead of going through the reverse proxy. No subprocess, no health checks, no TCP roundtrip.
**2a. Add testHandler to Process:**
```go
// In Process struct (process.go):
testHandler http.Handler // set only in tests; bypasses subprocess and reverse proxy
```
In `Process.Start()`, skip subprocess + health check when handler is set:
```go
func (p *Process) start() error {
if p.testHandler != nil {
p.setState(StateReady)
return nil
}
// existing subprocess logic...
}
```
In `Process.ProxyRequest()`, delegate directly to the handler:
```go
// Before the reverseProxy.ServeHTTP call:
if p.testHandler != nil {
p.testHandler.ServeHTTP(w, r)
return
}
```
**2b. Test helper to create the handler:**
```go
// newTestHandler returns an http.Handler that mimics llama.cpp's API
// (same endpoints as simple-responder).
func newTestHandler(respond string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { ... })
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { ... })
// ... other endpoints
return mux
}
```
Tests for routing/auth/CORS/streaming then become:
```go
func TestProxyManager_AuthRequired(t *testing.T) {
handler := newTestHandler("model1")
config := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
requiredAPIKeys: [test-key]
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
`)
pm := NewProxyManager(config)
// inject handler — skips subprocess, health check, port allocation
pm.processGroups["model1"].process.testHandler = handler
}
```
**Why this matters:** The handler is called directly in-process. No subprocess spawn, no health check timeout, no port allocation, no TCP roundtrip, no reverse proxy overhead. Routing tests go from ~100ms each (process startup + health check) to ~1ms. Unlike an `httptest.Server` approach, there are zero network hops.
**Why not blank-cmd + proxy URL:** A blank `cmd` with a `proxy` field pointing at `httptest.Server` still requires a real TCP roundtrip through the reverse proxy and introduces "external process" semantics to the config schema. Injecting the handler directly keeps it purely a test concern with no config changes.
**Scope:** Most tests in `proxymanager_test.go` (auth, CORS, model listing, streaming, peer proxy), `peerproxy_test.go`, `metrics_monitor_test.go`.
### Stage 3: Migrate tests incrementally
**Goal:** Convert existing tests to use the Stage 1 + Stage 2 helpers.
**Effort:** Medium | **Impact:** Cleaner, more reliable tests | **Risk:** None
Priority order:
1. `proxymanager_test.go` routing tests (highest count, most repetition)
2. `peerproxy_test.go` (straightforward, all HTTP routing)
3. `metrics_monitor_test.go` (capture logic doesn't need real processes)
4. `processgroup_test.go` swap tests (keep simple-responder for actual swap lifecycle tests)
Tests that **must keep simple-responder:**
- Process lifecycle: start/stop, SIGKILL, SIGTERM, TTL expiry, health check failures, failed start counting
- ProcessGroup swap concurrency (the port-collision test in `TestProcessGroup_ProxyRequestSwapIsTrueParallel`)
**Scope:** ~60-70% of tests can drop simple-responder.
### Stage 4 (optional): Process interface for ProcessGroup
**Goal:** Enable pure unit tests of ProcessGroup's swap/exclusive/concurrency logic without any HTTP server at all.
**Effort:** High | **Impact:** Pure unit tests possible | **Risk:** Medium (refactor core code)
```go
type ProcessController interface {
Start() error
Stop(StopStrategy)
ProxyRequest(http.ResponseWriter, *http.Request) error
CurrentState() ProcessState
ID() string
SetState(ProcessState) // for test setup
}
```
This requires:
- Extracting the interface
- A `MockProcess` implementation
- Refactoring `ProcessGroup` to use the interface instead of `*Process`
**Recommendation:** Only do this if ProcessGroup grows significantly more complex. Stages 1-3 give 80% of the benefit for 20% of the effort.
## Effort/Impact Summary
| Stage | Effort | Impact | Risk |
|-------|--------|--------|------|
| 1. YAML config helper | Low | Config bugs caught earlier | None |
| 2. Injected test handler | Medium | 10-100x faster routing tests | Low |
| 3. Migrate tests | Medium | Cleaner, more reliable tests | None |
| 4. Process interface | High | Pure unit tests possible | Medium |
**Recommended approach:** Do stages 1-3 in order. Each stage is independently valuable and can ship on its own. Stage 4 is deferred unless there's a specific need.
+292
View File
@@ -0,0 +1,292 @@
# Add Model Metadata Support with Typed Macros
## Overview
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
## Design Requirements
### 1. Enhanced Macro System
**Current State:**
- Macros are defined as `map[string]string` at both global and model levels
- Only string substitution is supported
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
**Required Changes:**
- Change `MacroList` type from `map[string]string` to `map[string]any`
- Support scalar types: `string`, `int`, `float64`, `bool`
- Implement type-preserving macro substitution:
- Direct macro usage (`key: ${macro}`) preserves the macro's type
- Interpolated usage (`key: "text ${macro}"`) converts to string
- Add validation to ensure macro values are scalar types only
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
**Implementation Details:**
- Create a generic helper function to perform macro substitution that:
- Takes a value of type `any`
- Recursively processes maps, slices, and scalar values
- Replaces `${macro_name}` patterns with macro values
- Preserves types for direct substitution
- Converts to strings for interpolated substitution
- Update `validateMacro()` function to accept `any` type and validate scalar types
- Maintain backward compatibility with existing string-only macros
### 2. Metadata Field in ModelConfig
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
**Required Changes:**
- Add `Metadata map[string]any` field to `ModelConfig` struct
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
- Apply macro substitution to metadata values during config loading
**Schema Requirements:**
- Metadata is optional (default: empty/nil map)
- Supports nested structures (objects within objects, arrays, etc.)
- All string values within metadata undergo macro substitution
- Type preservation rules apply as described above
### 3. Macro Substitution in Metadata
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
**Process Flow:**
1. After loading YAML configuration
2. After model-level and global macro merging
3. Apply macro substitution to `ModelConfig.Metadata` field
4. Use the same merged macros available to `cmd`, `proxy`, etc.
5. Process recursively through all nested structures
**Substitution Rules:**
- `port: ${PORT}` → keeps integer type from PORT macro
- `temperature: ${temp}` → keeps float type from temp macro
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
- Arrays and nested objects are processed recursively
- Unknown macros should cause configuration load error (consistent with existing behavior)
### 4. API Response Updates
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
**Current Behavior:**
- Returns model records with: `id`, `object`, `created`, `owned_by`
- Optionally includes: `name`, `description`
**Required Changes:**
- Add metadata to each model record under the key `llamaswap_meta`
- Only include `llamaswap_meta` if metadata is non-empty
- Preserve all types when marshaling to JSON
- Maintain existing sorting by model ID
**Example Response:**
```json
{
"object": "list",
"data": [
{
"id": "llama",
"object": "model",
"created": 1234567890,
"owned_by": "llama-swap",
"name": "llama 3.1 8B",
"description": "A small but capable model",
"llamaswap_meta": {
"port": 10001,
"temperature": 0.7,
"note": "The llama is running on port 10001 temp=0.7, context=16384",
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
"an_obj": {
"a": "1",
"b": 2,
"c": [0.7, false, "model: llama"]
}
}
}
]
}
```
### 5. Validation and Error Handling
**Macro Validation:**
- Extend `validateMacro()` to accept values of type `any`
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
- Reject complex types (maps, slices, structs) as macro values
- Maintain existing validation for macro names and lengths
**Configuration Loading:**
- Fail fast if unknown macros are found in metadata
- Provide clear error messages indicating which model and field contains errors
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
## Testing Plan
### Test 1: Model-Level Macros with Different Types
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
**Test Cases:**
- Define model with macros of each scalar type
- Verify metadata correctly substitutes and preserves types
- Test direct substitution (`port: ${PORT}`)
- Test string interpolation (`note: "Port is ${PORT}"`)
- Verify nested objects and arrays work correctly
### Test 2: Global and Model Macro Precedence
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
**Test Cases:**
- Define same macro at global and model level with different types
- Verify model-level macro takes precedence
- Test metadata uses correct macro value
- Verify type is preserved from the winning macro
### Test 3: Macro Validation
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
**Test Cases:**
- Test that complex types (maps, arrays) are rejected as macro values
- Verify error message includes: macro name and type that was rejected
- Test that scalar types (string, int, float, bool) are accepted
- Each type should load without error
- Test macro name validation still works with `any` types
- Invalid characters, reserved names, length limits should still be enforced
### Test 4: Metadata in API Response
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
**Existing Test:** `TestProxyManager_ListModelsHandler`
**Test Cases:**
- Model with metadata → verify `llamaswap_meta` key appears
- Model without metadata → verify `llamaswap_meta` key is absent
- Verify all types are correctly marshaled to JSON
- Verify nested structures are preserved
- Verify macro substitution has occurred before serialization
### Test 5: Unknown Macros in Metadata
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
**Test Cases:**
- Use undefined macro in metadata
- Verify configuration loading fails with clear error
- Error should indicate model name and that macro is undefined
### Test 6: Recursive Substitution
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
**Test Cases:**
- Metadata with deeply nested structures
- Arrays containing objects with macros
- Objects containing arrays with macros
- Mixed string interpolation and direct substitution at various nesting levels
## Checklist
### Configuration Schema Changes
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
- [x] Update `validateMacro()` function signature to accept `any` type for values
- [x] Add validation logic to ensure macro values are scalar types only
### Macro Substitution Logic
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
- [x] Implement type-preserving direct substitution logic
- [x] Implement string interpolation with type conversion
- [x] Handle maps: recursively process all values
- [x] Handle slices: recursively process all elements
- [x] Handle scalar types: perform string-based macro substitution if value is string
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
- [x] Update existing macro substitution calls to use merged macros with correct types
### API Response Changes
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
- [x] Add `llamaswap_meta` field to model records when metadata exists
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
- [x] Verify JSON marshaling preserves all types correctly
### Testing - Config Package
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
### Testing - Model Config Package
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
- [x] Test metadata with various scalar types
- [x] Test metadata with nested objects and arrays
### Testing - Proxy Manager
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
- [x] Add test case for model with metadata
- [x] Add test case for model without metadata
- [x] Verify `llamaswap_meta` key presence/absence
- [x] Verify type preservation in JSON output
- [x] Verify macro substitution has occurred
### Documentation
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
- [x] No additional documentation needed per project instructions
## Known Issues and Considerations
### Inconsistencies
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
### Design Decisions
1. **Why `llamaswap_meta` instead of merging into record?**
- Avoids potential collisions with OpenAI API standard fields
- Makes it clear this is llama-swap specific metadata
- Easier for clients to distinguish standard vs. custom fields
2. **Why support nested structures?**
- Provides maximum flexibility for users
- Aligns with the schemaless design principle
- Example config already demonstrates this capability
3. **Why validate macro types?**
- Prevents confusing behavior (e.g., substituting a map)
- Makes configuration errors explicit at load time
- Simpler implementation and testing
+397
View File
@@ -0,0 +1,397 @@
# Improve macro-in-macro support
**Status: COMPLETED ✅**
## Title
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
## Overview
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
**Outcomes:**
- Macros can reference other macros defined earlier in the config
- Macro substitution is deterministic and order-dependent
- Single-pass substitution prevents circular dependencies
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
- All existing tests pass
- New tests validate substitution order and self-reference detection
## Design Requirements
### 1. YAML Parsing Strategy
- **Continue using:** `gopkg.in/yaml.v3` (current library)
- **Use:** `yaml.Node` for ordered parsing of macros
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
### 2. Data Structure Changes
#### Current Implementation (config.go:19)
```go
type MacroList map[string]any
```
#### New Implementation
```go
type MacroList []MacroEntry
type MacroEntry struct {
Name string
Value any
}
```
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
### 3. Macro Substitution Order Rules
The substitution must follow this hierarchy (from most specific to least):
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
2. **Model-level macros** (middle): Defined in specific model config, overrides global
3. **Global macros** (first): Defined at config root level
Within each level, macros are substituted in **reverse definition order** (LIFO):
- The last macro defined is substituted first
- This allows later macros to reference earlier ones
- Single-pass substitution prevents circular dependencies
### 4. Macro Reference Rules
**Allowed:**
- Macro can reference any macro defined **before** it (earlier in the file)
- Model macros can reference global macros
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
**Prohibited:**
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
- Macro cannot reference macros defined **after** it
- No circular references (prevented by single-pass, ordered substitution)
### 5. Validation Requirements
Add validation to detect:
- **Self-references:** Macro value contains reference to its own name
- **Unknown macros:** After substitution, any remaining `${...}` references
Error messages should be clear:
```
macro 'foo' contains self-reference
unknown macro '${bar}' in model.cmd
```
### 6. Implementation Changes
#### Files to Modify
1. **[proxy/config/config.go](proxy/config/config.go)**
- Line 19: Change `MacroList` type definition
- Line 69: Update `Macros MacroList` field
- Line 153-157: Update macro validation loop to work with ordered structure
- Line 175-188: Update model-level macro validation
- Line 181-188: **NEW** Implement proper macro merging respecting order
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
- Line 389-415: Update `validateMacro` to detect self-references
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
- Line 33: Update `Macros MacroList` field type
3. **All test files**
- Update test fixtures to use ordered macro definitions
- Ensure tests specify macro order explicitly
#### Core Algorithm
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
```go
// Merge global config and model macros. Model macros take precedence
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
// Add global macros first
for _, entry := range config.Macros {
mergedMacros = append(mergedMacros, entry)
}
// Add model macros (can override global)
for _, entry := range modelConfig.Macros {
// Remove any existing global macro with same name
found := false
for i, existing := range mergedMacros {
if existing.Name == entry.Name {
mergedMacros[i] = entry // Override
found = true
break
}
}
if !found {
mergedMacros = append(mergedMacros, entry)
}
}
// Add reserved MODEL_ID macro at the end
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
// Check if PORT macro is needed
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
// enforce ${PORT} used in both cmd and proxy
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// Add PORT macro to the end (highest priority)
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
nextPort++
}
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
// This allows later macros to reference earlier ones
for i := len(mergedMacros) - 1; i >= 0; i-- {
entry := mergedMacros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
// Substitute in command fields
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
// Substitute in metadata (recursive)
if len(modelConfig.Metadata) > 0 {
var err error
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
}
}
```
Add this new helper function to replace `substituteMetadataMacros`:
```go
// substituteMacroInValue recursively substitutes a single macro in a value structure
// This is called once per macro, allowing LIFO substitution order
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
macroSlug := fmt.Sprintf("${%s}", macroName)
macroStr := fmt.Sprintf("%v", macroValue)
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if v == macroSlug {
return macroValue, nil
}
// Handle string interpolation
if strings.Contains(v, macroSlug) {
return strings.ReplaceAll(v, macroSlug, macroStr), nil
}
return v, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newMap[key] = newVal
}
return newMap, nil
case []any:
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newSlice[i] = newVal
}
return newSlice, nil
default:
// Return scalar types as-is
return value, nil
}
}
```
### 7. Self-Reference Detection
Add to `validateMacro` function:
```go
func validateMacro(name string, value any) error {
// ... existing validation ...
// Check for self-reference
if str, ok := value.(string); ok {
macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(str, macroSlug) {
return fmt.Errorf("macro '%s' contains self-reference", name)
}
}
return nil
}
```
## Testing Plan
### 1. Migration Tests
- **Test:** All existing macro tests still pass after YAML library migration
- **Files:** All `*_test.go` files with macro tests
### 2. Macro Order Tests
#### Test: Macro-in-macro substitution order
```yaml
macros:
"A": "value-A"
"B": "prefix-${A}-suffix"
models:
test:
cmd: "echo ${B}"
```
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
#### Test: LIFO substitution order
```yaml
macros:
"base": "/models"
"path": "${base}/llama"
"full": "${path}/model.gguf"
models:
test:
cmd: "load ${full}"
```
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
#### Test: Model macro overrides global
```yaml
macros:
"tag": "global"
"msg": "value-${tag}"
models:
test:
macros:
"tag": "model-level"
cmd: "echo ${msg}"
```
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
### 3. Reserved Macro Tests
#### Test: MODEL_ID substituted in macro
```yaml
macros:
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
models:
my-model:
cmd: "${podman-llama} -m model.gguf"
```
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
### 4. Error Detection Tests
#### Test: Self-reference detection
```yaml
macros:
"recursive": "value-${recursive}"
```
**Expected:** Error: `macro 'recursive' contains self-reference`
#### Test: Undefined macro reference
```yaml
macros:
"A": "value-${UNDEFINED}"
```
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
### 5. Regression Tests
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
- Ensure all pass without modification (except test fixtures if needed)
## Checklist
### Phase 1: Data Structure Changes
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
- [ ] Implement helper functions:
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
### Phase 2: Macro Validation Updates
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
- [ ] Test self-reference detection with new test case
### Phase 3: Macro Substitution Algorithm
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
- [ ] Substitute in metadata within same loop
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
- [ ] Ensure `PORT` is added after port assignment (if needed)
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
### Phase 4: Testing
- [ ] Run `make test-dev` - fix any static checking errors
- [ ] Add test: macro-in-macro basic substitution
- [ ] Add test: LIFO substitution order with 3+ macro levels
- [ ] Add test: MODEL_ID in global macro used by model
- [ ] Add test: PORT in global macro used by model
- [ ] Add test: model macro overrides global macro in substitution
- [ ] Add test: self-reference detection error
- [ ] Add test: undefined macro reference error
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
### Phase 5: Documentation
- [ ] Update plan status in this file (mark completed)
- [ ] Update CLAUDE.md if macro behavior needs documentation
- [ ] Verify no new error messages need user documentation
## Bug Example (Original Issue)
```yaml
macros:
"podman-llama": >
podman run --name ${MODEL_ID}
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
ghcr.io/ggml-org/llama.cpp:server-cuda
"standard-options": >
--no-mmap --jinja
"kv8": >
-fa on -ctk q8_0 -ctv q8_0
```
**Current Bug:**
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
**After Fix:**
- Macros substituted in LIFO order: `kv8``standard-options``podman-llama`
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
@@ -210,6 +210,11 @@ func main() {
})
})
r.GET("/v1/audio/voices", func(c *gin.Context) {
model := c.Query("model")
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
})
r.GET("/slow-respond", func(c *gin.Context) {
echo := c.Query("echo")
delay := c.Query("delay")
@@ -269,6 +274,43 @@ func main() {
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
})
// SD API endpoints
r.POST("/sdapi/v1/txt2img", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
c.JSON(http.StatusOK, gin.H{
"model": modelName,
"images": []string{},
})
})
r.POST("/sdapi/v1/img2img", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
c.JSON(http.StatusOK, gin.H{
"model": modelName,
"images": []string{},
})
})
r.GET("/sdapi/v1/loras", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"loras": []string{},
})
})
address := "127.0.0.1:" + *port // Address with the specified port
srv := &http.Server{
+27
View File
@@ -0,0 +1,27 @@
# wol-proxy
wol-proxy automatically wakes up a suspended llama-swap server using Wake-on-LAN when requests are received.
When a request arrives and llama-swap is unavailable, wol-proxy sends a WOL packet and holds the request until the server becomes available. If the server doesn't respond within the timeout period (default: 60 seconds), the request is dropped.
This utility helps conserve energy by allowing GPU-heavy servers to remain suspended when idle, as they can consume hundreds of watts even when not actively processing requests.
## Usage
```shell
# minimal
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080
# everything
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080 \
# use debug log level
-log debug \
# altenerative listening port
-listen localhost:9999 \
# seconds to hold requests waiting for upstream to be ready
-timeout 30
```
## API
`GET /status` - that's it. Everything else is proxied to the upstream server.
+64
View File
@@ -0,0 +1,64 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Loading...</title>
<style>
body {
font-family: sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background: #f5f5f5;
}
.loader {
text-align: center;
}
.stats {
font-size: 18px;
color: #333;
margin: 20px 0;
}
.stats-label {
color: #666;
font-size: 14px;
}
</style>
</head>
<body>
<div class="loader">
<p>Waking up upstream server...</p>
<div class="stats">
<div><span class="stats-label">Time elapsed:</span> <span id="elapsed">0s</span></div>
<div><span id="attempts">&nbsp;</span></div>
</div>
</div>
<script>
var startTime = Date.now();
var attempts = 0;
setInterval(function() {
var elapsed = (Date.now() - startTime) / 1000;
document.getElementById('elapsed').textContent = elapsed.toFixed(1) + 's';
}, 100);
// Check status every second
setInterval(function() {
attempts++;
var dots = '.'.repeat((attempts % 10) || 10);
document.getElementById('attempts').textContent = dots;
fetch('/status')
.then(function(r) { return r.text(); })
.then(function(t) {
if (t.indexOf('status: ready') !== -1) {
location.reload();
}
})
.catch(function() {});
}, 1000);
</script>
</body>
</html>
+333
View File
@@ -0,0 +1,333 @@
package main
import (
"bufio"
"context"
_ "embed"
"errors"
"flag"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"time"
)
//go:embed index.html
var loadingPageHTML string
var (
flagMac = flag.String("mac", "", "mac address to send WoL packet to")
flagUpstream = flag.String("upstream", "", "upstream proxy address to send requests to")
flagListen = flag.String("listen", ":8080", "listen address to listen on")
flagLog = flag.String("log", "info", "log level (debug, info, warn, error)")
flagTimeout = flag.Int("timeout", 60, "seconds requests wait for upstream response before failing")
)
func main() {
flag.Parse()
switch *flagLog {
case "debug":
slog.SetLogLoggerLevel(slog.LevelDebug)
case "info":
slog.SetLogLoggerLevel(slog.LevelInfo)
case "warn":
slog.SetLogLoggerLevel(slog.LevelWarn)
case "error":
slog.SetLogLoggerLevel(slog.LevelError)
default:
slog.Error("invalid log level", "logLevel", *flagLog)
return
}
// Validate flags
if *flagListen == "" {
slog.Error("listen address is required")
return
}
if *flagMac == "" {
slog.Error("mac address is required")
return
}
if *flagTimeout < 1 {
slog.Error("timeout must be greater than 0")
return
}
var upstreamURL *url.URL
var err error
// validate mac address
if _, err = net.ParseMAC(*flagMac); err != nil {
slog.Error("invalid mac address", "error", err)
return
}
if *flagUpstream == "" {
slog.Error("upstream proxy address is required")
return
} else {
upstreamURL, err = url.ParseRequestURI(*flagUpstream)
if err != nil {
slog.Error("error parsing upstream url", "error", err)
return
}
}
proxy := newProxy(upstreamURL)
server := &http.Server{
Addr: *flagListen,
Handler: proxy,
}
// start the server
go func() {
slog.Info("server starting on", "address", *flagListen)
if err := server.ListenAndServe(); err != nil {
slog.Error("error starting server", "error", err)
}
}()
// graceful shutdown
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
<-ctx.Done()
server.Close()
}
type upstreamStatus string
const (
notready upstreamStatus = "not ready"
ready upstreamStatus = "ready"
)
type proxyServer struct {
upstreamProxy *httputil.ReverseProxy
failCount int
statusMutex sync.RWMutex
status upstreamStatus
}
func newProxy(url *url.URL) *proxyServer {
p := httputil.NewSingleHostReverseProxy(url)
proxy := &proxyServer{
upstreamProxy: p,
status: notready,
failCount: 0,
}
// start a goroutine to monitor upstream status via SSE
go func() {
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
client := &http.Client{
Timeout: 0, // No timeout for SSE connection
}
waitDuration := 10 * time.Second
for {
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
req, err := http.NewRequest("GET", eventsUrl, nil)
if err != nil {
slog.Warn("failed to create SSE request", "error", err)
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(waitDuration)
continue
}
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
resp, err := client.Do(req)
if err != nil {
slog.Error("failed to connect to SSE endpoint", "error", err)
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(10 * time.Second)
continue
}
if resp.StatusCode != http.StatusOK {
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()
proxy.setStatus(notready)
proxy.incFail(1)
time.Sleep(10 * time.Second)
continue
}
// Successfully connected to SSE endpoint
slog.Info("connected to SSE endpoint, upstream ready")
proxy.setStatus(ready)
proxy.resetFailures()
// Read from the SSE stream to detect disconnection
scanner := bufio.NewScanner(resp.Body)
// use a fairly large buffer to avoid scanner errors when reading large SSE events
buf := make([]byte, 0, 1024*1024*2)
scanner.Buffer(buf, 1024*1024*2)
events := 0
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
fmt.Print("Events: ")
}
for scanner.Scan() {
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
// Just read the events to keep connection alive
// We don't need to process the event data
events++
fmt.Printf("%d, ", events)
}
}
fmt.Println()
if err := scanner.Err(); err != nil {
slog.Error("error reading from SSE stream", "error", err)
}
// Connection closed or error occurred
_ = resp.Body.Close()
slog.Info("SSE connection closed, upstream not ready")
proxy.setStatus(notready)
proxy.incFail(1)
// Wait before reconnecting
time.Sleep(waitDuration)
}
}()
return proxy
}
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && r.URL.Path == "/status" {
status := string(p.getStatus())
failCount := p.getFailures()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(200)
fmt.Fprintf(w, "status: %s\n", status)
fmt.Fprintf(w, "failures: %d\n", failCount)
return
}
if p.getStatus() == notready {
path := r.URL.Path
if strings.HasPrefix(path, "/api/events") {
slog.Debug("Skipping wake up", "req", path)
w.WriteHeader(http.StatusNoContent)
return
}
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
if err := sendMagicPacket(*flagMac); err != nil {
slog.Warn("failed to send magic WoL packet", "error", err)
}
// For root or UI path requests, return loading page with status polling
// the web page will do the polling and redirect when ready
if path == "/" || strings.HasPrefix(path, "/ui/") {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, loadingPageHTML)
return
}
ticker := time.NewTicker(250 * time.Millisecond)
timeout, cancel := context.WithTimeout(context.Background(), time.Duration(*flagTimeout)*time.Second)
defer cancel()
loop:
for {
select {
case <-timeout.Done():
slog.Info("timeout waiting for upstream to be ready")
http.Error(w, "timeout", http.StatusRequestTimeout)
return
case <-ticker.C:
if p.getStatus() == ready {
ticker.Stop()
break loop
}
}
}
}
p.upstreamProxy.ServeHTTP(w, r)
}
func (p *proxyServer) getStatus() upstreamStatus {
p.statusMutex.RLock()
defer p.statusMutex.RUnlock()
return p.status
}
func (p *proxyServer) setStatus(status upstreamStatus) {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.status = status
}
func (p *proxyServer) incFail(num int) {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.failCount += num
}
func (p *proxyServer) getFailures() int {
p.statusMutex.RLock()
defer p.statusMutex.RUnlock()
return p.failCount
}
func (p *proxyServer) resetFailures() {
p.statusMutex.Lock()
defer p.statusMutex.Unlock()
p.failCount = 0
}
func sendMagicPacket(macAddr string) error {
hwAddr, err := net.ParseMAC(macAddr)
if err != nil {
return err
}
if len(hwAddr) != 6 {
return errors.New("invalid MAC address")
}
// Create the magic packet.
packet := make([]byte, 102)
// Add 6 bytes of 0xFF.
for i := 0; i < 6; i++ {
packet[i] = 0xFF
}
// Repeat the MAC address 16 times.
for i := 1; i <= 16; i++ {
copy(packet[i*6:], hwAddr)
}
// Send the packet using UDP.
addr := net.UDPAddr{
IP: net.IPv4bcast,
Port: 9,
}
conn, err := net.DialUDP("udp", nil, &addr)
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Write(packet)
return err
}
+520
View File
@@ -0,0 +1,520 @@
{
"$schema": "https://json-schema.org/draft-07/schema#",
"$id": "llama-swap-config-schema.json",
"title": "llama-swap configuration",
"description": "Configuration file for llama-swap",
"type": "object",
"required": [
"models"
],
"definitions": {
"macros": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string",
"minLength": 0,
"maxLength": 1024
},
{
"type": "number"
},
{
"type": "boolean"
}
]
},
"propertyNames": {
"type": "string",
"minLength": 1,
"maxLength": 64,
"pattern": "^[a-zA-Z0-9_-]+$",
"not": {
"enum": [
"PORT",
"MODEL_ID"
]
}
},
"default": {},
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
},
"timeouts": {
"type": "object",
"properties": {
"connect": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP connection timeout in seconds. Set to 0 to disable."
},
"keepalive": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP keepalive timeout in seconds. Set to 0 to disable."
},
"responseHeader": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Time to wait for response headers in seconds. Set to 0 to disable."
},
"tlsHandshake": {
"type": "integer",
"minimum": 0,
"default": 10,
"description": "TLS handshake timeout in seconds. Set to 0 to disable."
},
"expectContinue": {
"type": "integer",
"minimum": 0,
"default": 1,
"description": "Expect-Continue timeout in seconds. Set to 0 to disable."
},
"idleConn": {
"type": "integer",
"minimum": 0,
"default": 90,
"description": "Idle connection timeout in seconds. Set to 0 to disable."
}
},
"additionalProperties": false,
"description": "Timeout settings for proxy connections."
}
},
"properties": {
"healthCheckTimeout": {
"type": "integer",
"minimum": 15,
"default": 120,
"description": "Number of seconds to wait for a model to be ready to serve requests."
},
"globalTTL": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Default TTL for all models in seconds, 0 means no TTL and models will never be automatically unloaded"
},
"logLevel": {
"type": "string",
"enum": [
"debug",
"info",
"warn",
"error"
],
"default": "info",
"description": "Sets the logging value. Valid values: debug, info, warn, error."
},
"logTimeFormat": {
"type": "string",
"enum": [
"",
"ansic",
"unixdate",
"rubydate",
"rfc822",
"rfc822z",
"rfc850",
"rfc1123",
"rfc1123z",
"rfc3339",
"rfc3339nano",
"kitchen",
"stamp",
"stampmilli",
"stampmicro",
"stampnano"
],
"default": "",
"description": "Enables and sets the logging timestamp format. Valid values: \"\", \"ansic\", \"unixdate\", \"rubydate\", \"rfc822\", \"rfc822z\", \"rfc850\", \"rfc1123\", \"rfc1123z\", \"rfc3339\", \"rfc3339nano\", \"kitchen\", \"stamp\", \"stampmilli\", \"stampmicro\", and \"stampnano\". For more info, read: https://pkg.go.dev/time#pkg-constants"
},
"metricsMaxInMemory": {
"type": "integer",
"default": 1000,
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
},
"captureBuffer": {
"type": "integer",
"minimum": 0,
"default": 5,
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
},
"startPort": {
"type": "integer",
"default": 5800,
"description": "Starting port number for the automatic ${PORT} macro. The ${PORT} macro is incremented for every model that uses it."
},
"sendLoadingState": {
"type": "boolean",
"default": false,
"description": "Inject loading status updates into the reasoning field. When true, a stream of loading messages will be sent to the client."
},
"includeAliasesInList": {
"type": "boolean",
"default": false,
"description": "Present aliases within the /v1/models OpenAI API listing. when true, model aliases will be output to the API model listing duplicating all fields except for Id so chat UIs can use the alias equivalent to the original."
},
"macros": {
"$ref": "#/definitions/macros"
},
"models": {
"type": "object",
"description": "A dictionary of model configurations. Each key is a model's ID. Model settings have defaults if not defined. The model's ID is available as ${MODEL_ID}.",
"additionalProperties": {
"type": "object",
"required": [
"cmd"
],
"properties": {
"macros": {
"$ref": "#/definitions/macros"
},
"cmd": {
"type": "string",
"minLength": 1,
"description": "Command to run to start the inference server. Macros can be used. Comments allowed with |."
},
"cmdStop": {
"type": "string",
"default": "",
"description": "Command to run to stop the model gracefully. Uses ${PID} macro for upstream process id. If empty, default shutdown behavior is used."
},
"name": {
"type": "string",
"default": "",
"maxLength": 128,
"description": "Display name for the model. Used in v1/models API response."
},
"description": {
"type": "string",
"default": "",
"maxLength": 1024,
"description": "Description for the model. Used in v1/models API response."
},
"env": {
"type": "array",
"items": {
"type": "string",
"pattern": "^[A-Z_][A-Z0-9_]*=.*$"
},
"default": [],
"description": "Array of environment variables to inject into cmd's environment. Each value is a string in ENV_NAME=value format."
},
"proxy": {
"type": "string",
"default": "http://localhost:${PORT}",
"format": "uri",
"description": "URL where llama-swap routes API requests. If custom port is used in cmd, this must be set."
},
"aliases": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"default": [],
"description": "Alternative model names for this configuration. Must be unique globally."
},
"checkEndpoint": {
"type": "string",
"default": "/health",
"pattern": "^/.*$|^none$",
"description": "URL path to check if the server is ready. Use 'none' to skip health checking."
},
"ttl": {
"type": "integer",
"minimum": -1,
"default": -1,
"description": "Automatically unload the model after ttl seconds. -1 uses the global TTL value, 0 disables unloading. Must be >0 to enable."
},
"useModelName": {
"type": "string",
"default": "",
"description": "Override the model name sent to upstream server. Useful if upstream expects a different name."
},
"filters": {
"type": "object",
"properties": {
"stripParams": {
"type": "string",
"default": "",
"pattern": "^[a-zA-Z0-9_, ]*$",
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
},
"setParams": {
"type": "object",
"additionalProperties": true,
"default": {},
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
},
"setParamsByID": {
"type": "object",
"additionalProperties": {
"type": "object",
"additionalProperties": true
},
"default": {},
"description": "Dictionary mapping requested model IDs (or aliases) to parameters to set/override in requests. Applied after setParams and can override those values. Useful with aliases to vary behaviour depending on which alias the client used (e.g. different reasoning_effort per alias). Keys support ${MODEL_ID} macro substitution. Protected params like 'model' cannot be overridden."
}
},
"additionalProperties": false,
"default": {},
"description": "Dictionary of filter settings. Supports stripParams, setParams, and setParamsByID."
},
"metadata": {
"type": "object",
"additionalProperties": true,
"default": {},
"description": "Dictionary of arbitrary values included in /v1/models. Can contain complex types. Only passed through in /v1/models responses."
},
"concurrencyLimit": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Overrides allowed number of active parallel requests to a model. 0 uses internal default of 10. >0 overrides default. Requests exceeding limit get HTTP 429."
},
"sendLoadingState": {
"type": "boolean",
"description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting."
},
"unlisted": {
"type": "boolean",
"default": false,
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
},
"timeouts": {
"$ref": "#/definitions/timeouts"
}
}
}
},
"groups": {
"type": "object",
"additionalProperties": {
"type": "object",
"required": [
"members"
],
"properties": {
"swap": {
"type": "boolean",
"default": true,
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
},
"exclusive": {
"type": "boolean",
"default": true,
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
},
"persistent": {
"type": "boolean",
"default": false,
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
},
"members": {
"type": "array",
"items": {
"type": "string"
},
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
}
}
},
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
},
"matrix": {
"type": "object",
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
"required": [
"vars",
"sets"
],
"properties": {
"vars": {
"type": "object",
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
"minProperties": 1,
"additionalProperties": {
"type": "string"
},
"propertyNames": {
"pattern": "^[a-zA-Z0-9]{1,8}$"
}
},
"evict_costs": {
"type": "object",
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
"additionalProperties": {
"type": "integer",
"minimum": 1
}
},
"sets": {
"type": "object",
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
"minProperties": 1,
"additionalProperties": {
"type": "string"
}
}
},
"additionalProperties": false
},
"hooks": {
"type": "object",
"properties": {
"on_startup": {
"type": "object",
"properties": {
"preload": {
"type": "array",
"items": {
"type": "string"
},
"default": [],
"description": "List of model IDs to load on startup. Model names must match keys in models. When preloading multiple models, define a group to prevent swapping."
}
},
"additionalProperties": false,
"description": "Actions to perform on startup. Only supported action is preload."
}
},
"additionalProperties": false,
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
},
"logToStdout": {
"type": "string",
"enum": [
"proxy",
"upstream",
"both",
"none"
],
"default": "proxy",
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
},
"apiKeys": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"default": [],
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
},
"peers": {
"type": "object",
"additionalProperties": {
"type": "object",
"required": [
"proxy",
"models"
],
"properties": {
"proxy": {
"type": "string",
"format": "uri",
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
},
"apiKey": {
"type": "string",
"default": "",
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
},
"models": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"description": "A list of models served by the peer."
},
"filters": {
"type": "object",
"properties": {
"stripParams": {
"type": "string",
"default": "",
"pattern": "^[a-zA-Z0-9_, ]*$",
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
},
"setParams": {
"type": "object",
"additionalProperties": true,
"default": {},
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
}
},
"additionalProperties": false,
"default": {},
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
},
"timeouts": {
"type": "object",
"properties": {
"connect": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP connection timeout in seconds."
},
"keepalive": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP keepalive connection timeout in seconds."
},
"responseHeader": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Time to wait for response headers in seconds."
},
"tlsHandshake": {
"type": "integer",
"minimum": 0,
"default": 10,
"description": "TLS handshake timeout in seconds."
},
"idleConn": {
"type": "integer",
"minimum": 0,
"default": 90,
"description": "Idle connection timeout in seconds."
}
},
"additionalProperties": false,
"description": "Timeout settings for proxy connections to this peer."
}
}
},
"default": {},
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
}
},
"allOf": [
{
"if": {
"required": ["groups"]
},
"then": {
"not": {
"required": ["matrix"]
}
}
},
{
"if": {
"required": ["matrix"]
},
"then": {
"not": {
"required": ["groups"]
}
}
}
]
}
+323 -82
View File
@@ -1,3 +1,6 @@
# add this modeline for validation in vscode
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
#
# llama-swap YAML configuration example
# -------------------------------------
#
@@ -23,28 +26,103 @@ healthCheckTimeout: 500
# - Valid log levels: debug, info, warn, error
logLevel: info
# logTimeFormat: enables and sets the logging timestamp format
# - optional, default (disabled): ""
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
# "stamp", "stampmilli", "stampmicro", and "stampnano".
# - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: ""
# logToStdout: controls what is logged to stdout
# - optional, default: "proxy"
# - valid values:
# - "proxy": logs generated by llama-swap when swapping models,
# handling requests, etc.
# - "upstream": a copy of an upstream processes stdout logs
# - "both": both the proxy and upstream logs interleaved together
# - "none": no logs are ever written to stdout
logToStdout: "proxy"
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
# - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000
# captureBuffer: how many MBs to allocate for storing request/response captures
# - optional, default: 10
# - set to 0 to disable
captureBuffer: 15
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
# - it is automatically incremented for every model that uses it
startPort: 10001
# sendLoadingState: inject loading status updates into the reasoning (thinking)
# field
# - optional, default: false
# - when true, a stream of loading messages will be sent to the client in the
# reasoning field so chat UIs can show that loading is in progress.
# - see #366 for more details
sendLoadingState: true
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
# - optional, default: false
# - when true, model aliases will be output to the API model listing duplicating
# all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false
# globalTTL: the default TTL in seconds before unloading a model
# - optional, default: 0 (never automatically unload)
# - must be >= 0
globalTTL: 0
# macros: a dictionary of string substitutions
# - optional, default: empty dictionary
# - macros are reusable snippets
# - used in a model's cmd, cmdStop, proxy and checkEndpoint
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
# - useful for reducing common configuration settings
# - macro names are strings and must be less than 64 characters
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
# - macro names must not be a reserved name: PORT or MODEL_ID
# - macro values can be numbers, bools, or strings
# - macros can contain other macros, but they must be defined before they are used
# - environment variables can be referenced with ${env.VAR_NAME} syntax
# - env macros are substituted first, before regular macros
# - if the env var is not set, config loading will fail with an error
macros:
# Example of a multi-line macro
"latest-llama": >
/path/to/llama-server/llama-server-ec9e0301
--port ${PORT}
"default_ctx": 4096
# Example of macro-in-macro usage. macros can contain other macros
# but they must be previously declared.
"default_args": "--ctx-size ${default_ctx}"
# Example of environment variable macros
# - ${env.VAR_NAME} pulls the value from the system environment
# - useful for paths, secrets, or machine-specific configuration
"models_dir": "${env.HOME}/models"
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
# use environment variable macros to keep secrets out of the config
- "${env.API_KEY_1}"
- "${env.API_KEY_2}"
# models: a dictionary of model configurations
# - required
# - each key is the model's ID, used in API requests
@@ -52,9 +130,16 @@ macros:
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
# - below are examples of the all the settings a model can have
models:
# keys are the model names used in API requests
"llama":
"gpt-oss-120b":
# macros: a dictionary of string substitutions specific to this model
# - optional, default: empty dictionary
# - macros defined here override macros defined in the global macros section
# - model level macros follow the same rules as global macros
macros:
"default_ctx": 16384
"temp": 0.7
# cmd: the command to run to start the inference server.
# - required
# - it is just a string, similar to what you would run on the CLI
@@ -63,19 +148,21 @@ models:
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model path/to/llama-8B-Q4_K_M.gguf
--model path/to/gpt-oss-120B.gguf
--ctx-size ${default_ctx}
--temperature ${temp}
# name: a display name for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
name: "llama 3.1 8B"
name: "gpt-oss 120B"
# description: a description for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
description: "A small but capable model used for quick testing"
description: "A thinking model from OpenAI"
# env: define an array of environment variables to inject into cmd's environment
# - optional, default: empty array
@@ -90,14 +177,6 @@ models:
# - if you use a custom port in cmd this *must* be set
proxy: http://127.0.0.1:8999
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# checkEndpoint: URL path to check if the server is ready
# - optional, default: /health
# - endpoint is expected to return an HTTP 200 response
@@ -106,8 +185,10 @@ models:
checkEndpoint: /custom-endpoint
# ttl: automatically unload the model after ttl seconds
# - optional, default: 0
# - ttl values must be a value greater than 0
# - optional, default: -1 (use global default)
# - ttl values must be a value greater than or equal to 0
# - a ttl of -1 will use the global TTL value as the default
# - a ttl of 0 will mean never unload
# - a value of 0 disables automatic unloading of the model
ttl: 60
@@ -115,19 +196,80 @@ models:
# - optional, default: ""
# - useful for when the upstream server expects a specific model name that
# is different from the model's ID
useModelName: "qwen:qwq"
useModelName: "openai/gpt-oss-120B"
# filters: a dictionary of filter settings
# - optional, default: empty dictionary
# - only strip_params is currently supported
# - same capabilities as peer filters (stripParams, setParams)
filters:
# strip_params: a comma separated list of parameters to remove from the request
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for server side enforcement of sampling parameters
# - the `model` parameter can never be removed
# - can be any JSON key in the request body
# - recommended to stick to sampling parameters
strip_params: "temperature, top_p, top_k"
stripParams: "temperature, top_p, top_k"
# setParams: a dictionary of parameters to set/override in requests
# - optional, default: empty dictionary
# - useful for enforcing specific parameter values
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
# - always runs for the model
setParams:
# Example: enforce specific sampling parameters
temperature: 0.7
top_p: 0.9
# setParamsByID: a dictionary of parameters to set based the model ID
# - optional, default: empty dictionary
# - combine with aliases to create variant behaviour without reloading the model
# - parameters are set in the request body JSON
# - run after setParams so it will override any settings
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
# - model aliases will be automatically created for each key
setParamsByID:
"${MODEL_ID}":
chat_template_kwargs:
reasoning_effort: medium
"${MODEL_ID}:high":
chat_template_kwargs:
reasoning_effort: high
"${MODEL_ID}:low":
chat_template_kwargs:
reasoning_effort: low
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
# metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple
# - metadata is only passed through in /v1/models responses
metadata:
# port will remain an integer
port: ${PORT}
# the ${temp} macro will remain a float
temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
a_list:
- 1
- 1.23
- "macros are OK in list and dictionary types: ${MODEL_ID}"
an_obj:
a: "1"
b: 2
# objects can contain complex types with macro substitution
# becomes: c: [0.7, false, "model: llama"]
c: ["${temp}", false, "model: ${MODEL_ID}"]
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
# - optional, default: 0
@@ -138,6 +280,26 @@ models:
# - recommended to be omitted and the default used
concurrencyLimit: 0
# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
# timeouts: configure proxy connection timeouts for this model
# - optional, defaults shown below
# - useful for models running on slower hardware that need longer timeouts
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
# - idleConn: idle connection timeout in seconds, default: 90 seconds
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
keepalive: 0
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
@@ -169,68 +331,83 @@ models:
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID}
# groups: a dictionary of group settings
# - optional, default: empty dictionary
# - provides advanced controls over model swapping behaviour
# - using groups some models can be kept loaded indefinitely, while others are swapped out
# - model IDs must be defined in the Models section
# - a model can only be a member of one group
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
# - see issue #109 for details
# =============================================================================
# matrix: run concurrent models with a solver-based swap DSL
# =============================================================================
#
# NOTE: the example below uses model names that are not defined above for demonstration purposes
groups:
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
# to run a time across the whole llama-swap instance
"group1":
# swap: controls the model swapping behaviour in within the group
# - optional, default: true
# - true : only one model is allowed to run at a time
# - false: all models can run together, no swapping
swap: true
# Note:
# A config must use either a matrix or legacy groups, not both. A configuration error
# will occur if both are defined. Configuration examples for legacy Groups can be found:
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
#
# The matrix declares valid combinations of models that can run concurrently.
# When a model is requested, the solver finds the cheapest way to make it
# available by evicting as few (and least costly) running models as possible.
#
# Solver behavior:
# 1. Request arrives for model X
# 2. If X is already running, forward immediately. Done.
# 3. Find all sets containing X
# 4. For each candidate set, compute cost: sum of evict_costs for
# every running model NOT in that set
# 5. Pick lowest cost candidate. Ties broken by definition order.
# 6. Evict what needs to stop. Start X. Forward request.
#
# Subset semantics: a set [a, b, c] means any subset is valid.
# Only the requested model is started — others are not preloaded.
#
# A model not appearing in any set can only run alone.
#
matrix:
# vars: short names for models (alphanumeric, 1-8 chars)
# - required for sets and evict_costs settings
# - each entry is a short name to a real model ID. Do not use an alias
# - used to keep set DSL logic short and easier to read
# - sets and evict_costs only use identifiers defined in vars
vars:
g: gemma-model
q: qwen-model
m: mistral-model
v: voxtral-model
e: reranker-model
L: llama-70B
sd: stable-diffusion
# exclusive: controls how the group affects other groups
# - optional, default: true
# - true: causes all other groups to unload when this group runs a model
# - false: does not affect other groups
exclusive: true
# evict_costs: relative cost of losing a running model (default: 1)
evict_costs:
v: 50 # vllm backend, slow cold start
L: 30 # 70B weights, slow to load
# members references the models defined above
# required
members:
- "llama"
- "qwen-unlisted"
# sets: named sets of concurrent model combinations
# Values are DSL strings with operators:
# & AND (models run together)
# | OR (alternatives)
# () grouping
# +ref inline another set's expression
#
# Expansion examples:
# "L" → [L]
# "a & b" → [a, b]
# "a | b" → [a], [b]
# "(a | b) & c" → [a, c], [b, c]
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
# "+llms & v" → expands llms inline, then applies & v
sets:
# LLM + TTS: switching between g/q/m won't evict v
# expands to: [g,v], [q,v], [m,v]
standard: "(g | q | m) & v"
# Example:
# - in group2 all models can run at the same time
# - when a different group is loaded it causes all running models in this group to unload
"group2":
swap: false
# LLM + TTS + reranker
# expands to: [g,v,e], [q,v,e]
with_rerank: "(g | q) & v & e"
# exclusive: false does not unload other groups when a model in group2 is requested
# - the models in group2 will be loaded but will not unload any other groups
exclusive: false
members:
- "docker-llama"
- "modelA"
- "modelB"
# LLM + image generation, no TTS
# expands to: [g,sd], [q,sd]
creative: "(g | q) & sd"
# Example:
# - a persistent group, prevents other groups from unloading it
"forever":
# persistent: prevents over groups from unloading the models in this group
# - optional, default: false
# - does not affect individual model behaviour
persistent: true
# set swap/exclusive to false to prevent swapping inside the group
# and the unloading of other groups
swap: false
exclusive: false
members:
- "forever-modelA"
- "forever-modelB"
- "forever-modelc"
# 70B model uses all GPUs, can only run alone
# expands to: [L]
full: "L"
# hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary
@@ -240,10 +417,74 @@ hooks:
# - optional, default: empty dictionary
# - the only supported action is preload
on_startup:
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
preload:
- "llama"
- "llama"
# peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary
# - peers can be another llama-swap
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
peers:
# keys is the peer'd ID
llama-swap-peer:
# proxy: a valid base URL to proxy requests to
# - required
# - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23
# models: a list of models served by the peer
# - required
models:
- model_a
- model_b
- embeddings/model_c
openrouter:
proxy: https://openrouter.ai/api
# apiKey: a string key to be injected into the request
# - optional, default: ""
# - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
# - can be a string or a macro
apiKey: ${env.OPENROUTER_API_KEY}
models:
- meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507
- deepseek/deepseek-v3.2
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
# timeouts: configure proxy connection timeouts for this peer
# - optional, defaults shown below
# - useful when the peer runs on slower hardware
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
keepalive: 30
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# filters: a dictionary of filter settings for peer requests
# - optional, default: empty dictionary
# - same capabilities as model filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for removing parameters that the peer doesn't support
# - the `model` parameter can never be removed
stripParams: "temperature, top_p"
# setParams: a dictionary of parameters to set/override in requests to this peer
# - optional, default: empty dictionary
# - useful for injecting provider-specific settings like data retention policies
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
setParams:
# Example: enforce zero-data-retention for OpenRouter
provider:
data_collection: "deny"
zdr: true
+139 -30
View File
@@ -1,55 +1,164 @@
#!/bin/bash
set -euo pipefail
cd $(dirname "$0")
# use this to test locally, example:
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
# you need read:package scope on the token. Generate a personal access token with
# the scopes: gist, read:org, repo, write:packages
# then: gh auth login (and copy/paste the new token)
LOG_DEBUG=${LOG_DEBUG:-0}
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
log_debug() {
if [ "$LOG_DEBUG" = "1" ]; then
echo "[DEBUG] $*"
fi
}
log_info() {
echo "[INFO] $*"
}
ARCH=$1
PUSH_IMAGES=${2:-false}
# List of allowed architectures
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cuda13" "cpu" "rocm")
# Check if ARCH is in the allowed list
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
exit 1
fi
# Check if GITHUB_TOKEN is set and not empty
if [[ -z "$GITHUB_TOKEN" ]]; then
echo "Error: GITHUB_TOKEN is not set or is empty."
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
log_info "Error: GITHUB_TOKEN is not set or is empty."
exit 1
fi
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
# variable, this permits testing with forked llama.cpp repositories
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
# to enable easy container builds on forked repos
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
# Fetches the most recent llama.cpp tag matching the given prefix
# Handles pagination to search beyond the first 100 results
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
# Returns: the version number extracted from the tag
fetch_llama_tag() {
local tag_prefix=$1
local page=1
local per_page=100
while true; do
log_debug "Fetching page $page for tag prefix: $tag_prefix"
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
# Check for API errors
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
local error_msg=$(echo "$response" | jq -r '.message')
log_info "GitHub API error: $error_msg"
return 1
fi
# Check if response is empty array (no more pages)
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
log_debug "No more pages (empty response)"
return 1
fi
# Extract matching tag from this page
local found_tag=$(echo "$response" | jq -r \
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
| sort -r | head -n1)
if [ -n "$found_tag" ]; then
log_debug "Found tag: $found_tag on page $page"
echo "$found_tag" | awk -F '-' '{print $NF}'
return 0
fi
page=$((page + 1))
# Safety limit to prevent infinite loops
if [ $page -gt 50 ]; then
log_info "Reached pagination safety limit (50 pages)"
return 1
fi
done
}
if [ "$ARCH" == "cpu" ]; then
# cpu only containers just use the latest available
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
echo "Building ${CONTAINER_LATEST} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_LATEST}
fi
LCPP_TAG=$(fetch_llama_tag "server")
BASE_TAG=server-${LCPP_TAG}
else
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
BASE_TAG=server-${ARCH}-${LCPP_TAG}
fi
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
fi
SD_TAG=master-${ARCH}
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
echo "Building ${CONTAINER_TAG} $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
fi
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
log_info "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
else
log_info "LCPP_TAG: $LCPP_TAG"
fi
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
log_info "Abort: DEBUG_ABORT_BUILD set"
exit 0
fi
for CONTAINER_TYPE in non-root root; do
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
USER_UID=0
USER_GID=0
USER_HOME=/root
if [ "$CONTAINER_TYPE" == "non-root" ]; then
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
USER_UID=10001
USER_GID=10001
USER_HOME=/app
fi
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
--build-arg BASE_IMAGE=${BASE_IMAGE} .
# For architectures with stable-diffusion.cpp support, layer sd-server on top
case "$ARCH" in
"musa" | "vulkan")
log_info "Adding sd-server to $CONTAINER_TAG"
docker build --provenance=false -f llama-swap-sd.Containerfile \
--build-arg BASE=${CONTAINER_TAG} \
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
esac
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
fi
done
+305
View File
@@ -0,0 +1,305 @@
#!/bin/bash
#
# Build script for llama-swap-docker with commit hash pinning
#
# Usage:
# ./build-image.sh --cuda # Build CUDA image
# ./build-image.sh --vulkan # Build Vulkan image
# ./build-image.sh --cuda --no-cache # Build CUDA image without cache
# LLAMA_COMMIT_HASH=abc123 ./build-image.sh --cuda # Override llama.cpp commit
# LLAMA_COMMIT_HASH=b8429 ./build-image.sh --vulkan # Override llama.cpp release tag (vulkan uses prebuilt binaries)
# WHISPER_COMMIT_HASH=def456 ./build-image.sh --vulkan # Override whisper.cpp commit
# SD_COMMIT_HASH=ghi789 ./build-image.sh --cuda # Override stable-diffusion.cpp commit
#
# Features:
# - Auto-detects latest commit hashes from git repos
# - Builds llama-swap from local source code
# - Allows environment variable overrides for reproducible builds
# - Cache-friendly: changing commit hash busts cache appropriately
# - Supports both CUDA and Vulkan backends (requires explicit flag)
#
set -euo pipefail
# Parse command line arguments
BACKEND=""
NO_CACHE=false
if [[ $# -eq 0 ]]; then
echo "Error: No backend specified. Please use --cuda or --vulkan."
echo ""
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
echo ""
echo "Options:"
echo " --cuda Build CUDA image (NVIDIA GPUs)"
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
echo " --no-cache Force rebuild without using Docker cache"
echo " --help, -h Show this help message"
echo ""
echo "Environment variables:"
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
exit 1
fi
for arg in "$@"; do
case $arg in
--cuda)
BACKEND="cuda"
;;
--vulkan)
BACKEND="vulkan"
;;
--no-cache)
NO_CACHE=true
;;
--help|-h)
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
echo ""
echo "Options:"
echo " --cuda Build CUDA image (NVIDIA GPUs)"
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
echo " --no-cache Force rebuild without using Docker cache"
echo " --help, -h Show this help message"
echo ""
echo "Environment variables:"
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
exit 0
;;
esac
done
# Validate backend selection
if [[ -z "$BACKEND" ]]; then
echo "Error: No backend specified. Please use --cuda or --vulkan."
exit 1
fi
# Configuration
if [[ -n "${DOCKER_IMAGE_TAG:-}" ]]; then
# User provided a custom tag, use it as-is
:
elif [[ "$BACKEND" == "vulkan" ]]; then
DOCKER_IMAGE_TAG="llama-swap:vulkan"
else
DOCKER_IMAGE_TAG="llama-swap:cuda"
fi
DOCKER_BUILDKIT="${DOCKER_BUILDKIT:-1}"
# Single unified Dockerfile, backend selected via build arg
DOCKERFILE="Dockerfile"
if [[ "$BACKEND" == "vulkan" ]]; then
echo "Building for: Vulkan (AMD GPUs and compatible hardware)"
else
echo "Building for: CUDA (NVIDIA GPUs)"
fi
# Git repository URLs
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
# Function to get the latest commit hash from a git repo's default branch
get_latest_commit() {
local repo_url="$1"
local branch="${2:-master}"
# Try to get the latest commit hash for the specified branch
git ls-remote --heads "${repo_url}" "${branch}" 2>/dev/null | head -1 | cut -f1
}
# Function to get the default branch name (master or main)
get_default_branch() {
local repo_url="$1"
# Check for master first
if git ls-remote --heads "${repo_url}" master &>/dev/null; then
echo "master"
elif git ls-remote --heads "${repo_url}" main &>/dev/null; then
echo "main"
else
echo "master" # fallback
fi
}
# Function to get the latest release tag from a GitHub repo
get_latest_release_tag() {
local owner_repo="$1"
curl -fsSL "https://api.github.com/repos/${owner_repo}/releases/latest" \
| grep '"tag_name"' | head -1 | cut -d'"' -f4
}
echo "=========================================="
echo "llama-swap-docker Build Script"
echo "=========================================="
echo ""
# Determine commit hashes / release tags - use env vars or auto-detect
# For vulkan builds, llama and sd use GitHub release tags (prebuilt binaries).
# For cuda builds (or whisper on any backend), use git commit hashes.
if [[ -n "${LLAMA_COMMIT_HASH:-}" ]]; then
LLAMA_HASH="${LLAMA_COMMIT_HASH}"
echo "llama.cpp: Using provided version: ${LLAMA_HASH}"
elif [[ "$BACKEND" == "vulkan" ]]; then
LLAMA_HASH=$(get_latest_release_tag "ggml-org/llama.cpp")
if [[ -z "${LLAMA_HASH}" ]]; then
echo "ERROR: Could not determine latest release tag for llama.cpp" >&2
exit 1
fi
echo "llama.cpp: Auto-detected latest release tag: ${LLAMA_HASH}"
else
LLAMA_BRANCH=$(get_default_branch "${LLAMA_REPO}")
LLAMA_HASH=$(get_latest_commit "${LLAMA_REPO}" "${LLAMA_BRANCH}")
if [[ -z "${LLAMA_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
exit 1
fi
echo "llama.cpp: Auto-detected latest commit (${LLAMA_BRANCH}): ${LLAMA_HASH}"
fi
if [[ -n "${WHISPER_COMMIT_HASH:-}" ]]; then
WHISPER_HASH="${WHISPER_COMMIT_HASH}"
echo "whisper.cpp: Using provided commit hash: ${WHISPER_HASH}"
else
WHISPER_BRANCH=$(get_default_branch "${WHISPER_REPO}")
WHISPER_HASH=$(get_latest_commit "${WHISPER_REPO}" "${WHISPER_BRANCH}")
if [[ -z "${WHISPER_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
exit 1
fi
echo "whisper.cpp: Auto-detected latest commit (${WHISPER_BRANCH}): ${WHISPER_HASH}"
fi
if [[ -n "${SD_COMMIT_HASH:-}" ]]; then
SD_HASH="${SD_COMMIT_HASH}"
echo "stable-diffusion.cpp: Using provided version: ${SD_HASH}"
elif [[ "$BACKEND" == "vulkan" ]]; then
SD_HASH=$(get_latest_release_tag "leejet/stable-diffusion.cpp")
if [[ -z "${SD_HASH}" ]]; then
echo "ERROR: Could not determine latest release tag for stable-diffusion.cpp" >&2
exit 1
fi
echo "stable-diffusion.cpp: Auto-detected latest release tag: ${SD_HASH}"
else
SD_BRANCH=$(get_default_branch "${SD_REPO}")
SD_HASH=$(get_latest_commit "${SD_REPO}" "${SD_BRANCH}")
if [[ -z "${SD_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
exit 1
fi
echo "stable-diffusion.cpp: Auto-detected latest commit (${SD_BRANCH}): ${SD_HASH}"
fi
echo ""
echo "=========================================="
echo "Starting Docker build..."
echo "=========================================="
echo ""
# Build the Docker image with commit hashes as build args
# Build context is the repository root (..) so the Dockerfile can access Go source
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
BUILD_ARGS=(
--build-arg "BACKEND=${BACKEND}"
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
-t "${DOCKER_IMAGE_TAG}"
-f "${SCRIPT_DIR}/${DOCKERFILE}"
)
if [[ "$NO_CACHE" == true ]]; then
BUILD_ARGS+=(--no-cache)
echo "Note: Building without cache"
fi
# Use docker buildx with a custom builder for parallelism control
# The legacy DOCKER_BUILDKIT=1 docker build doesn't respect BUILDKIT_MAX_PARALLELISM env var
# We need to use a custom builder with a buildkitd.toml config file
BUILDER_NAME="llama-swap-builder"
# Check if our custom builder exists with the right config, create/update if needed
if ! docker buildx inspect "$BUILDER_NAME" >/dev/null 2>&1; then
echo "Creating custom buildx builder with max-parallelism=1..."
# Create buildkitd.toml config file
cat > buildkitd.toml << 'BUILDKIT_EOF'
[worker.oci]
max-parallelism = 1
BUILDKIT_EOF
# Create the builder with the config
docker buildx create --name "$BUILDER_NAME" \
--driver docker-container \
--buildkitd-config buildkitd.toml \
--use
else
# Switch to our builder
docker buildx use "$BUILDER_NAME"
fi
echo "Building with sequential stages (one at a time), each using all CPU cores..."
echo "Using builder: $BUILDER_NAME"
# Use docker buildx build with --load to load the image into Docker
# The --builder flag ensures we use our custom builder with max-parallelism=1
# Build context is the repository root so we can access Go source files
docker buildx build --builder "$BUILDER_NAME" --load "${BUILD_ARGS[@]}" "${REPO_ROOT}"
echo ""
echo "=========================================="
echo "Verifying build artifacts..."
echo "=========================================="
echo ""
# Verify all expected binaries exist in the image
MISSING_BINARIES=()
for binary in llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap; do
if ! docker run --rm "${DOCKER_IMAGE_TAG}" which "${binary}" >/dev/null 2>&1; then
MISSING_BINARIES+=("${binary}")
fi
done
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
echo "ERROR: Build succeeded but the following binaries are missing from the image:"
for binary in "${MISSING_BINARIES[@]}"; do
echo " - ${binary}"
done
echo ""
echo "This usually indicates a build stage failure. Try running with --no-cache flag:"
echo " ./build-image.sh --vulkan --no-cache"
exit 1
fi
echo "All expected binaries verified: llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
echo ""
echo "=========================================="
echo "Build complete!"
echo "=========================================="
echo ""
echo "Image tag: ${DOCKER_IMAGE_TAG}"
echo ""
echo "Built with:"
echo " llama.cpp: ${LLAMA_HASH}"
echo " whisper.cpp: ${WHISPER_HASH}"
echo " stable-diffusion.cpp: ${SD_HASH}"
echo " llama-swap: $(docker run --rm "${DOCKER_IMAGE_TAG}" cat /versions.txt | grep llama-swap | cut -d' ' -f2-)"
echo ""
if [[ "$BACKEND" == "vulkan" ]]; then
echo "Run with:"
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
echo ""
echo "Note: For AMD GPUs, you may also need to mount render devices:"
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
else
echo "Run with:"
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
fi
+16 -1
View File
@@ -15,4 +15,19 @@ models:
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
--port 9999
z-image:
checkEndpoint: /
cmd: |
/app/sd-server
--listen-port 9999
--diffusion-fa
--diffusion-model /models/z_image_turbo-Q8_0.gguf
--vae /models/ae.safetensors
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
--offload-to-cpu
--cfg-scale 1.0
--height 512 --width 512
--steps 8
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
+11
View File
@@ -0,0 +1,11 @@
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
ARG SD_TAG=master-vulkan
ARG BASE=llama-swap:latest
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
FROM ${BASE}
ARG UID=10001
ARG GID=10001
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
+36 -8
View File
@@ -1,16 +1,44 @@
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
FROM ${BASE_IMAGE}:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=89
ARG LS_VER=170
ARG LS_REPO=mostlygeek/llama-swap
# Set default UID/GID arguments
ARG UID=10001
ARG GID=10001
ARG USER_HOME=/app
# Add user/group
ENV HOME=$USER_HOME
RUN if [ $UID -ne 0 ]; then \
if [ $GID -ne 0 ]; then \
groupadd --system --gid $GID app; \
fi; \
useradd --system --uid $UID --gid $GID \
--home $USER_HOME app; \
fi
# Handle paths
RUN mkdir --parents $HOME /app
RUN chown --recursive $UID:$GID $HOME /app
# Switch user
USER $UID:$GID
WORKDIR /app
RUN \
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
COPY config.example.yaml /app/config.yaml
# Add /app to PATH
ENV PATH="/app:${PATH}"
RUN \
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
+204
View File
@@ -0,0 +1,204 @@
# Unified multi-stage Dockerfile for AI inference tools
# Supports CUDA and Vulkan backends via BACKEND build arg
#
# Usage:
# docker buildx build --build-arg BACKEND=cuda -t llama-swap:unified-cuda .
# docker buildx build --build-arg BACKEND=vulkan -t llama-swap:unified-vulkan .
# docker buildx build --build-arg BACKEND=cuda --build-arg CMAKE_CUDA_ARCHITECTURES="86;89" -t llama-swap:unified-cuda .
#
# Each project has its own install script that handles cloning, building,
# and installing binaries. Build stages are independent for cache efficiency.
ARG BACKEND=cuda
# ── Builder bases ──────────────────────────────────────────────────────
FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 AS builder-base-cuda
ARG CMAKE_CUDA_ARCHITECTURES="60;61;75;86;89"
ENV DEBIAN_FRONTEND=noninteractive
ENV CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}
ENV CCACHE_DIR=/ccache
ENV CCACHE_MAXSIZE=2G
ENV PATH="/usr/lib/ccache:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git python3 python3-pip libssl-dev \
curl ca-certificates ccache make wget \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
# ──
FROM ubuntu:24.04 AS builder-base-vulkan
ENV DEBIAN_FRONTEND=noninteractive
ENV CCACHE_DIR=/ccache
ENV CCACHE_MAXSIZE=2G
ENV PATH="/usr/lib/ccache:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git python3 python3-pip libssl-dev \
curl ca-certificates ccache make wget software-properties-common \
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
spirv-headers \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
# ── Select builder base by BACKEND ────────────────────────────────────
FROM builder-base-${BACKEND} AS builder-base
# ── Build whisper.cpp (fastest build, run first) ──────────────────────
FROM builder-base AS whisper-build
ARG BACKEND=cuda
ARG WHISPER_COMMIT_HASH=master
COPY install-whisper.sh /build/
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
--mount=type=cache,id=whisper-${BACKEND},target=/src/whisper.cpp/build \
BACKEND=${BACKEND} bash /build/install-whisper.sh "${WHISPER_COMMIT_HASH}"
# ── Build stable-diffusion.cpp ────────────────────────────────────────
FROM builder-base AS sd-build
ARG BACKEND=cuda
ARG SD_COMMIT_HASH=master
COPY install-sd.sh /build/
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
--mount=type=cache,id=sd-${BACKEND},target=/src/stable-diffusion.cpp/build \
BACKEND=${BACKEND} bash /build/install-sd.sh "${SD_COMMIT_HASH}"
# ── Build llama.cpp (slowest build, run last) ─────────────────────────
FROM builder-base AS llama-build
ARG BACKEND=cuda
ARG LLAMA_COMMIT_HASH=master
COPY install-llama.sh /build/
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
--mount=type=cache,id=llama-${BACKEND},target=/src/llama.cpp/build \
BACKEND=${BACKEND} bash /build/install-llama.sh "${LLAMA_COMMIT_HASH}"
# ── Build ik_llama.cpp (CUDA only) ────────────────────────────────────
#
# Two named stages allow ARG BACKEND to select at build time:
# - ik-llama-cuda : real build (from builder-base-cuda)
# - ik-llama-vulkan: no-op (empty /install/bin, skips CUDA pull entirely)
# BuildKit only evaluates the selected branch, so vulkan builds never
# pull nvidia/cuda:*-devel or compile ik_llama.cpp.
FROM builder-base-vulkan AS ik-llama-vulkan
RUN mkdir -p /install/bin
FROM builder-base-cuda AS ik-llama-cuda
ARG IK_LLAMA_COMMIT_HASH=main
COPY install-ik-llama.sh /build/
RUN --mount=type=cache,id=ccache-cuda,target=/ccache \
--mount=type=cache,id=ik-llama-cuda,target=/src/ik_llama.cpp/build \
bash /build/install-ik-llama.sh "${IK_LLAMA_COMMIT_HASH}"
ARG BACKEND=cuda
FROM ik-llama-${BACKEND} AS ik-llama-build
# ── Download llama-swap release binary ────────────────────────────────
FROM builder-base AS llama-swap-download
ARG LS_VERSION=latest
COPY install-llama-swap.sh /build/
RUN bash /build/install-llama-swap.sh "${LS_VERSION}"
# ── Runtime bases ─────────────────────────────────────────────────────
FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04 AS runtime-cuda
ENV DEBIAN_FRONTEND=noninteractive
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV PATH="/usr/local/bin:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
libgomp1 python3 curl ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# CUDA stub drivers for container compatibility
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
# ──
FROM ubuntu:24.04 AS runtime-vulkan
ENV DEBIAN_FRONTEND=noninteractive
ENV PATH="/usr/local/bin:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
libgomp1 libvulkan1 mesa-vulkan-drivers \
python3 curl ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# ── Select runtime base by BACKEND ────────────────────────────────────
FROM runtime-${BACKEND} AS runtime
ARG BACKEND=cuda
ARG LLAMA_COMMIT_HASH=unknown
ARG WHISPER_COMMIT_HASH=unknown
ARG SD_COMMIT_HASH=unknown
ARG IK_LLAMA_COMMIT_HASH=unknown
ARG RUN_UID=0
RUN apt-get update && apt-get install -y --no-install-recommends \
python3-numpy python3-sentencepiece \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user when RUN_UID != 0
RUN if [ "$RUN_UID" != "0" ]; then \
groupadd --system --gid $RUN_UID llama-swap && \
useradd --system --uid $RUN_UID --gid $RUN_UID \
--home /app --shell /sbin/nologin llama-swap; \
fi && \
mkdir -p /etc/llama-swap/config && \
chown -R ${RUN_UID}:${RUN_UID} /etc/llama-swap
WORKDIR /app
# Copy whisper.cpp binaries and libraries
COPY --from=whisper-build /install/bin/whisper-server /usr/local/bin/
COPY --from=whisper-build /install/bin/whisper-cli /usr/local/bin/
COPY --from=whisper-build /install/lib/ /usr/local/lib/
# Copy stable-diffusion.cpp binaries and libraries
COPY --from=sd-build /install/bin/sd-server /usr/local/bin/
COPY --from=sd-build /install/bin/sd-cli /usr/local/bin/
COPY --from=sd-build /install/lib/ /usr/local/lib/
# Copy llama.cpp binaries (statically linked)
COPY --from=llama-build /install/bin/llama-server /usr/local/bin/
COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
# Copy ik-llama-server (CUDA only; empty copy for vulkan)
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
# Copy llama-swap binary
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
COPY --from=llama-swap-download /install/llama-swap-version /tmp/
RUN ldconfig
COPY config.example.yaml /etc/llama-swap/config/config.yaml
# Version tracking
RUN echo "llama.cpp: ${LLAMA_COMMIT_HASH}" > /versions.txt && \
echo "whisper.cpp: ${WHISPER_COMMIT_HASH}" >> /versions.txt && \
echo "stable-diffusion.cpp: ${SD_COMMIT_HASH}" >> /versions.txt && \
echo "ik_llama.cpp: ${IK_LLAMA_COMMIT_HASH}" >> /versions.txt && \
echo "llama-swap: $(cat /tmp/llama-swap-version)" >> /versions.txt && \
echo "backend: ${BACKEND}" >> /versions.txt && \
echo "build_timestamp: $(date -u +%Y-%m-%dT%H:%M:%SZ)" >> /versions.txt
RUN mkdir -p /models && chown ${RUN_UID}:${RUN_UID} /models
WORKDIR /models
USER ${RUN_UID}
ENTRYPOINT ["llama-swap"]
CMD ["-config", "/etc/llama-swap/config/config.yaml", "-listen", "0.0.0.0:8080"]
+8
View File
@@ -0,0 +1,8 @@
# Unified Docker Container
These scripts create a custom llama-swap container that contains:
- llama-server for LLMs, rerank and embedding model support
- sd-server (stable-diffusion.cpp) for image generation
- whisper.cpp for ASR
+303
View File
@@ -0,0 +1,303 @@
#!/bin/bash
#
# Build script for unified container with version pinning
#
# Usage:
# ./build-image.sh --cuda # Build CUDA image
# ./build-image.sh --vulkan # Build Vulkan image
# ./build-image.sh --cuda --no-cache # Build without cache
# LLAMA_REF=b1234 ./build-image.sh --vulkan # Pin llama.cpp to a commit hash
# LLAMA_REF=v1.2.3 ./build-image.sh --cuda # Pin llama.cpp to a tag
# WHISPER_REF=v1.0.0 ./build-image.sh --vulkan # Pin whisper.cpp to a tag
# SD_REF=master ./build-image.sh --cuda # Pin stable-diffusion.cpp to a branch
# LS_VERSION=170 ./build-image.sh --cuda # Override llama-swap version
# IK_LLAMA_REF=main ./build-image.sh --cuda # Pin ik_llama.cpp to main branch (CUDA only)
#
set -euo pipefail
BACKEND=""
NO_CACHE=false
for arg in "$@"; do
case $arg in
--cuda)
BACKEND="cuda"
;;
--vulkan)
BACKEND="vulkan"
;;
--no-cache)
NO_CACHE=true
;;
--help|-h)
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
echo ""
echo "Options:"
echo " --cuda Build CUDA image (NVIDIA GPUs)"
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
echo " --no-cache Force rebuild without using Docker cache"
echo " --help, -h Show this help message"
echo ""
echo "Environment variables:"
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:unified-cuda or llama-swap:unified-vulkan)"
echo " LLAMA_REF Pin llama.cpp to a commit, tag, or branch"
echo " WHISPER_REF Pin whisper.cpp to a commit, tag, or branch"
echo " SD_REF Pin stable-diffusion.cpp to a commit, tag, or branch"
echo " IK_LLAMA_REF Pin ik_llama.cpp to a commit, tag, or branch (CUDA only)"
echo " LS_VERSION Override llama-swap version (e.g., '170' or 'latest')"
exit 0
;;
esac
done
if [[ -z "$BACKEND" ]]; then
echo "Error: No backend specified. Please use --cuda or --vulkan."
echo ""
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
exit 1
fi
DOCKER_IMAGE_TAG="${DOCKER_IMAGE_TAG:-llama-swap:unified-${BACKEND}}"
# Git repository URLs
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
LLAMA_SWAP_REPO="https://github.com/mostlygeek/llama-swap.git"
IK_LLAMA_REPO="https://github.com/ikawrakow/ik_llama.cpp.git"
# Resolve a git ref (commit hash, tag, or branch) to a full commit hash.
# Requires only: git, network access to the remote.
resolve_ref() {
local repo_url="$1"
local ref="$2"
# Full 40-char SHA — use as-is
if [[ "${ref}" =~ ^[0-9a-f]{40}$ ]]; then
echo "${ref}"
return
fi
# Try tag then branch (exact match)
local hash
hash=$(git ls-remote "${repo_url}" "refs/tags/${ref}" "refs/heads/${ref}" 2>/dev/null | head -1 | cut -f1)
if [[ -n "${hash}" ]]; then
echo "${hash}"
return
fi
# Short hash (7+ chars): scan all refs for a SHA with this prefix
if [[ "${ref}" =~ ^[0-9a-f]{7,}$ ]]; then
hash=$(git ls-remote "${repo_url}" 2>/dev/null | grep "^${ref}" | head -1 | cut -f1)
if [[ -n "${hash}" ]]; then
echo "${hash}"
return
fi
fi
echo "ERROR: Could not resolve ref '${ref}' for ${repo_url}" >&2
if [[ "${ref}" =~ ^[0-9a-f]+$ && ${#ref} -lt 7 ]]; then
echo " Short hashes must be at least 7 characters (got ${#ref})." >&2
else
echo " Tried: tag, branch, git ls-remote prefix match" >&2
fi
echo " Use a full 40-char SHA, a tag name, a branch name, or a 7+ char short hash." >&2
return 1
}
# Resolve HEAD of a repo without needing to know the default branch name.
get_latest_hash() {
git ls-remote "${1}" HEAD 2>/dev/null | head -1 | cut -f1
}
echo "=========================================="
echo "llama-swap Unified Build (${BACKEND})"
echo "=========================================="
echo ""
# Resolve llama.cpp ref
if [[ -n "${LLAMA_REF:-}" ]]; then
LLAMA_HASH=$(resolve_ref "${LLAMA_REPO}" "${LLAMA_REF}") || exit 1
echo "llama.cpp: ${LLAMA_REF} -> ${LLAMA_HASH}"
else
LLAMA_HASH=$(get_latest_hash "${LLAMA_REPO}")
if [[ -z "${LLAMA_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
exit 1
fi
echo "llama.cpp: latest HEAD: ${LLAMA_HASH}"
fi
# Resolve whisper.cpp ref
if [[ -n "${WHISPER_REF:-}" ]]; then
WHISPER_HASH=$(resolve_ref "${WHISPER_REPO}" "${WHISPER_REF}") || exit 1
echo "whisper.cpp: ${WHISPER_REF} -> ${WHISPER_HASH}"
else
WHISPER_HASH=$(get_latest_hash "${WHISPER_REPO}")
if [[ -z "${WHISPER_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
exit 1
fi
echo "whisper.cpp: latest HEAD: ${WHISPER_HASH}"
fi
# Resolve stable-diffusion.cpp ref
if [[ -n "${SD_REF:-}" ]]; then
SD_HASH=$(resolve_ref "${SD_REPO}" "${SD_REF}") || exit 1
echo "stable-diffusion.cpp: ${SD_REF} -> ${SD_HASH}"
else
SD_HASH=$(get_latest_hash "${SD_REPO}")
if [[ -z "${SD_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
exit 1
fi
echo "stable-diffusion.cpp: latest HEAD: ${SD_HASH}"
fi
# Resolve ik_llama.cpp ref (CUDA only)
if [[ "$BACKEND" == "cuda" ]]; then
if [[ -n "${IK_LLAMA_REF:-}" ]]; then
IK_LLAMA_HASH=$(resolve_ref "${IK_LLAMA_REPO}" "${IK_LLAMA_REF}") || exit 1
echo "ik_llama.cpp: ${IK_LLAMA_REF} -> ${IK_LLAMA_HASH}"
else
IK_LLAMA_HASH=$(get_latest_hash "${IK_LLAMA_REPO}")
if [[ -z "${IK_LLAMA_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for ik_llama.cpp" >&2
exit 1
fi
echo "ik_llama.cpp: latest HEAD: ${IK_LLAMA_HASH}"
fi
else
IK_LLAMA_HASH="n/a"
echo "ik_llama.cpp: skipped (vulkan build)"
fi
# Resolve llama-swap ref
if [[ -n "${LS_VERSION:-}" ]]; then
LS_HASH=$(resolve_ref "${LLAMA_SWAP_REPO}" "${LS_VERSION}") || exit 1
echo "llama-swap: ${LS_VERSION} -> ${LS_HASH}"
else
LS_HASH=$(get_latest_hash "${LLAMA_SWAP_REPO}")
if [[ -z "${LS_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for llama-swap" >&2
exit 1
fi
echo "llama-swap: latest HEAD: ${LS_HASH}"
fi
echo ""
echo "=========================================="
echo "Starting Docker build..."
echo "=========================================="
echo ""
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BUILD_ARGS=(
--build-arg "BACKEND=${BACKEND}"
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
--build-arg "IK_LLAMA_COMMIT_HASH=${IK_LLAMA_HASH}"
--build-arg "LS_VERSION=${LS_HASH}"
-t "${DOCKER_IMAGE_TAG}"
-f "${SCRIPT_DIR}/Dockerfile"
)
if [[ "$NO_CACHE" == true ]]; then
BUILD_ARGS+=(--no-cache)
echo "Note: Building without cache"
elif [[ "${GITHUB_ACTIONS:-}" == "true" && "${ACT:-}" != "true" ]]; then
CACHE_REF="ghcr.io/mostlygeek/llama-swap:unified-${BACKEND}-cache"
BUILD_ARGS+=(
--cache-from "type=registry,ref=${CACHE_REF}"
--cache-to "type=registry,ref=${CACHE_REF},mode=max"
)
echo "Note: Using registry cache (${CACHE_REF})"
fi
DOCKER_BUILDKIT=1 docker buildx build --load "${BUILD_ARGS[@]}" "${SCRIPT_DIR}"
echo ""
echo "=========================================="
echo "Verifying build artifacts..."
echo "=========================================="
echo ""
EXPECTED_BINARIES=(llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap)
if [[ "$BACKEND" == "cuda" ]]; then
EXPECTED_BINARIES+=(ik-llama-server)
fi
MISSING_BINARIES=()
for binary in "${EXPECTED_BINARIES[@]}"; do
if ! docker run --rm --entrypoint which "${DOCKER_IMAGE_TAG}" "${binary}" >/dev/null 2>&1; then
MISSING_BINARIES+=("${binary}")
fi
done
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
echo "ERROR: Build succeeded but the following binaries are missing:"
for binary in "${MISSING_BINARIES[@]}"; do
echo " - ${binary}"
done
echo ""
echo "Try running with --no-cache flag:"
echo " ./build-image.sh --${BACKEND} --no-cache"
exit 1
fi
VERIFIED_LIST="llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
if [[ "$BACKEND" == "cuda" ]]; then
VERIFIED_LIST="${VERIFIED_LIST}, ik-llama-server"
fi
echo "All expected binaries verified: ${VERIFIED_LIST}"
echo ""
echo "=========================================="
echo "Building rootless image..."
echo "=========================================="
echo ""
ROOTLESS_TAG="${DOCKER_IMAGE_TAG}-rootless"
docker buildx build --load -t "${ROOTLESS_TAG}" - <<EOF
FROM ${DOCKER_IMAGE_TAG}
USER root
RUN groupadd --system --gid 10001 llama-swap && \\
useradd --system --uid 10001 --gid 10001 \\
--home /app --shell /sbin/nologin llama-swap && \\
chown -R 10001:10001 /etc/llama-swap /models
USER 10001
EOF
echo "Rootless image built: ${ROOTLESS_TAG}"
echo ""
echo "=========================================="
echo "Build complete!"
echo "=========================================="
echo ""
echo "Image tags:"
echo " ${DOCKER_IMAGE_TAG}"
echo " ${ROOTLESS_TAG}"
echo ""
echo "Built with:"
echo " llama.cpp: ${LLAMA_HASH}"
echo " whisper.cpp: ${WHISPER_HASH}"
echo " stable-diffusion.cpp: ${SD_HASH}"
if [[ "$BACKEND" == "cuda" ]]; then
echo " ik_llama.cpp: ${IK_LLAMA_HASH}"
fi
echo " llama-swap: $(docker run --rm --entrypoint cat "${DOCKER_IMAGE_TAG}" /versions.txt | grep llama-swap | cut -d' ' -f2-)"
echo ""
if [[ "$BACKEND" == "vulkan" ]]; then
echo "Run with:"
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
echo ""
echo "Note: For AMD GPUs, you may also need:"
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
else
echo "Run with:"
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
fi
+33
View File
@@ -0,0 +1,33 @@
# placeholder example configuration
healthCheckTimeout: 300
logRequests: true
models:
"llama":
cmd: >
llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port ${PORT}
"whisper":
checkEndpoint: /v1/audio/transcriptions/
cmd: >
whisper-server
--port ${PORT}
--m /models/whisper.bin
--flash-attn
--request-path /v1/audio/transcriptions --inference-path ""
"image":
checkEndpoint: /
cmd: |
/app/sd-server
--listen-port 9999
--diffusion-fa
--diffusion-model /models/z_image_turbo-Q8_0.gguf
--vae /models/ae.safetensors
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
--offload-to-cpu
--cfg-scale 1.0
--height 512 --width 512
--steps 8
+48
View File
@@ -0,0 +1,48 @@
#!/bin/bash
# Install ik_llama.cpp - clone, build, and install binaries
# Usage: ./install-ik-llama.sh <commit_hash>
# Note: CUDA only; always built against builder-base-cuda
set -e
COMMIT_HASH="${1:-main}"
mkdir -p /install/bin
# Clone and checkout (init-based so cache-mounted build dir doesn't break clone)
echo "=== Cloning ik_llama.cpp at ${COMMIT_HASH} ==="
mkdir -p /src/ik_llama.cpp
cd /src/ik_llama.cpp
if [ ! -d .git ]; then
git init
git remote add origin https://github.com/ikawrakow/ik_llama.cpp.git
fi
git fetch --depth=1 origin "${COMMIT_HASH}"
git checkout FETCH_HEAD
CMAKE_FLAGS=(
-DGGML_NATIVE=OFF
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER_LAUNCHER=ccache
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-DGGML_CUDA=ON
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda -Wl,--allow-shlib-undefined"
)
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
echo "=== Building ik_llama.cpp ==="
cmake -B build "${CMAKE_FLAGS[@]}"
cmake --build build --config Release -j"$(nproc)" --target llama-server
if [ ! -f "build/bin/llama-server" ]; then
echo "FATAL: llama-server not found in build/bin/" >&2
exit 1
fi
# Install as ik-llama-server to avoid collision with llama.cpp's llama-server
cp "build/bin/llama-server" "/install/bin/ik-llama-server"
echo "=== ik_llama.cpp build complete ==="
ls -la /install/bin/
+59
View File
@@ -0,0 +1,59 @@
#!/bin/bash
# Install llama-swap - download latest release binary from GitHub
# Usage: ./install-llama-swap.sh [version]
# version: release version number (e.g., "170") or "latest" (default)
set -e
VERSION="${1:-latest}"
REPO="mostlygeek/llama-swap"
mkdir -p /install/bin
# If a full commit hash is given, find the release tag that points to it
if echo "${VERSION}" | grep -qE '^[0-9a-f]{40}$'; then
echo "=== Resolving commit ${VERSION:0:7} to release tag ==="
TAG=$(git ls-remote --tags "https://github.com/${REPO}.git" 2>/dev/null \
| grep "^${VERSION}" | sed 's|.*refs/tags/||' | grep -v '\^{}' | head -1)
if [ -n "${TAG}" ]; then
echo "Resolved to tag: ${TAG}"
VERSION="${TAG#v}"
else
echo "No release tag found for commit ${VERSION:0:7}, using latest"
VERSION="latest"
fi
fi
# Strip leading 'v' prefix so both "198" and "v198" work
VERSION="${VERSION#v}"
# Resolve "latest" to actual version number
if [ "$VERSION" = "latest" ]; then
echo "=== Resolving latest llama-swap release ==="
VERSION=$(curl -fsSL "https://api.github.com/repos/${REPO}/releases/latest" \
| grep '"tag_name"' | head -1 | cut -d'"' -f4 | sed 's/^v//')
if [ -z "$VERSION" ]; then
echo "FATAL: Could not determine latest release version" >&2
exit 1
fi
echo "Latest version: ${VERSION}"
fi
# Download and extract
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_amd64.tar.gz"
echo "=== Downloading llama-swap v${VERSION} ==="
echo "URL: $URL"
curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
tar -xzf /tmp/llama-swap.tar.gz -C /install/bin/
rm /tmp/llama-swap.tar.gz
# Validate
if [ ! -x "/install/bin/llama-swap" ]; then
echo "FATAL: llama-swap binary not found or not executable" >&2
ls -la /install/bin/ >&2
exit 1
fi
echo "$VERSION" > /install/llama-swap-version
echo "=== llama-swap v${VERSION} installed ==="
ls -la /install/bin/llama-swap
+63
View File
@@ -0,0 +1,63 @@
#!/bin/bash
# Install llama.cpp - clone, build, and install binaries
# Usage: BACKEND=cuda|vulkan ./install-llama.sh <commit_hash>
set -e
COMMIT_HASH="${1:-master}"
BACKEND="${BACKEND:-cuda}"
mkdir -p /install/bin
# Clone and checkout (init-based so cache-mounted /src/llama.cpp/build dir doesn't break clone)
echo "=== Cloning llama.cpp at ${COMMIT_HASH} ==="
mkdir -p /src/llama.cpp
cd /src/llama.cpp
if [ ! -d .git ]; then
git init
git remote add origin https://github.com/ggml-org/llama.cpp.git
fi
git fetch --depth=1 origin "${COMMIT_HASH}"
git checkout FETCH_HEAD
# Common cmake flags
CMAKE_FLAGS=(
-DGGML_NATIVE=OFF
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER_LAUNCHER=ccache
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-DLLAMA_BUILD_TESTS=OFF
)
if [ "$BACKEND" = "cuda" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=ON
-DGGML_VULKAN=OFF
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
)
elif [ "$BACKEND" = "vulkan" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=OFF
-DGGML_VULKAN=ON
)
fi
TARGETS=(llama-cli llama-server)
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
echo "=== Building llama.cpp for ${BACKEND} ==="
cmake -B build "${CMAKE_FLAGS[@]}"
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
for bin in "${TARGETS[@]}"; do
if [ ! -f "build/bin/$bin" ]; then
echo "FATAL: $bin not found in build/bin/" >&2
exit 1
fi
cp "build/bin/$bin" "/install/bin/"
done
echo "=== llama.cpp build complete ==="
ls -la /install/bin/
+68
View File
@@ -0,0 +1,68 @@
#!/bin/bash
# Install stable-diffusion.cpp - clone, build, and install binaries and library
# Usage: BACKEND=cuda|vulkan ./install-sd.sh <commit_hash>
set -e
COMMIT_HASH="${1:-master}"
BACKEND="${BACKEND:-cuda}"
mkdir -p /install/bin /install/lib
# Clone and checkout (init-based so cache-mounted /src/stable-diffusion.cpp/build dir doesn't break clone)
echo "=== Cloning stable-diffusion.cpp at ${COMMIT_HASH} ==="
mkdir -p /src/stable-diffusion.cpp
cd /src/stable-diffusion.cpp
if [ ! -d .git ]; then
git init
git remote add origin https://github.com/leejet/stable-diffusion.cpp.git
fi
git fetch --depth=1 origin "${COMMIT_HASH}"
git checkout FETCH_HEAD
git submodule update --init --recursive --depth=1
# Common cmake flags
CMAKE_FLAGS=(
-DGGML_NATIVE=OFF
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER_LAUNCHER=ccache
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-DSD_BUILD_EXAMPLES=ON
)
if [ "$BACKEND" = "cuda" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=ON
-DGGML_VULKAN=OFF
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
-DSD_CUDA=ON
)
elif [ "$BACKEND" = "vulkan" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=OFF
-DGGML_VULKAN=ON
-DSD_VULKAN=ON
)
fi
TARGETS=(stable-diffusion sd-cli sd-server)
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
echo "=== Building stable-diffusion.cpp for ${BACKEND} ==="
cmake -B build "${CMAKE_FLAGS[@]}"
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
for bin in sd-cli sd-server; do
if [ ! -f "build/bin/$bin" ]; then
echo "FATAL: $bin not found in build/bin/" >&2
exit 1
fi
cp "build/bin/$bin" "/install/bin/"
done
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
echo "=== stable-diffusion.cpp build complete ==="
ls -la /install/bin/ /install/lib/
+64
View File
@@ -0,0 +1,64 @@
#!/bin/bash
# Install whisper.cpp - clone, build, and install binaries
# Usage: BACKEND=cuda|vulkan ./install-whisper.sh <commit_hash>
set -e
COMMIT_HASH="${1:-master}"
BACKEND="${BACKEND:-cuda}"
mkdir -p /install/bin /install/lib
# Clone and checkout (init-based so cache-mounted /src/whisper.cpp/build dir doesn't break clone)
echo "=== Cloning whisper.cpp at ${COMMIT_HASH} ==="
mkdir -p /src/whisper.cpp
cd /src/whisper.cpp
if [ ! -d .git ]; then
git init
git remote add origin https://github.com/ggml-org/whisper.cpp.git
fi
git fetch --depth=1 origin "${COMMIT_HASH}"
git checkout FETCH_HEAD
# Common cmake flags
CMAKE_FLAGS=(
-DGGML_NATIVE=OFF
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER_LAUNCHER=ccache
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
)
if [ "$BACKEND" = "cuda" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=ON
-DGGML_VULKAN=OFF
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
)
elif [ "$BACKEND" = "vulkan" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=OFF
-DGGML_VULKAN=ON
)
fi
TARGETS=(whisper-cli whisper-server)
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
echo "=== Building whisper.cpp for ${BACKEND} ==="
cmake -B build "${CMAKE_FLAGS[@]}"
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
for bin in "${TARGETS[@]}"; do
if [ ! -f "build/bin/$bin" ]; then
echo "FATAL: $bin not found in build/bin/" >&2
exit 1
fi
cp "build/bin/$bin" "/install/bin/"
done
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
echo "=== whisper.cpp build complete ==="
ls -la /install/bin/

Before

Width:  |  Height:  |  Size: 261 KiB

After

Width:  |  Height:  |  Size: 261 KiB

Before

Width:  |  Height:  |  Size: 351 KiB

After

Width:  |  Height:  |  Size: 351 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

+582
View File
@@ -0,0 +1,582 @@
# config.yaml
llama-swap is designed to be very simple: one binary, one configuration file.
## minimal viable config
```yaml
models:
model1:
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
```
This is enough to launch `llama-server` to serve `model1`. Of course, llama-swap is about making it possible to serve many models:
```yaml
models:
model1:
cmd: llama-server --port ${PORT} -m /path/to/model.gguf
model2:
cmd: llama-server --port ${PORT} -m /path/to/another_model.gguf
model3:
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
```
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model which is useful if you want to run multiple models at the same time with the `matrix` feature.
## Advanced control with `cmd`
llama-swap is also about customizability. You can use any CLI flag available:
```yaml
models:
model1:
cmd: | # support for multi-line
llama-server --PORT ${PORT} -m /path/to/model.gguf
--ctx-size 8192
--jinja
--cache-type-k q8_0
--cache-type-v q8_0
```
## Support for any OpenAI API compatible server
llama-swap supports any OpenAI API compatible server. If you can run it on the CLI llama-swap will be able to manage it. Even if it's run in Docker or Podman containers.
```yaml
models:
"Q3-30B-CODER-VLLM":
name: "Qwen3 30B Coder vllm AWQ (Q3-30B-CODER-VLLM)"
# cmdStop provides a reliable way to stop containers
cmdStop: docker stop vllm-coder
cmd: |
docker run --init --rm --name vllm-coder
--runtime=nvidia --gpus '"device=2,3"'
--shm-size=16g
-v /mnt/nvme/vllm-cache:/root/.cache
-v /mnt/ssd-extra/models:/models -p ${PORT}:8000
vllm/vllm-openai:v0.10.0
--model "/models/cpatonn/Qwen3-Coder-30B-A3B-Instruct-AWQ"
--served-model-name "Q3-30B-CODER-VLLM"
--enable-expert-parallel
--swap-space 16
--max-num-seqs 512
--max-model-len 65536
--max-seq-len-to-capture 65536
--gpu-memory-utilization 0.9
--tensor-parallel-size 2
--trust-remote-code
```
## Many more features..
llama-swap supports many more features to customize how you want to manage your environment.
| Feature | Description |
| --------- | ---------------------------------------------- |
| `ttl` | automatic unloading of models after a timeout |
| `macros` | reusable snippets to use in configurations |
| `matrix` | run multiple models at a time |
| `hooks` | event driven functionality |
| `env` | define environment variables per model |
| `aliases` | serve a model with different names |
| `filters` | modify requests before sending to the upstream |
| `...` | And many more tweaks |
## Full Configuration Example
> [!NOTE]
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
```yaml
# add this modeline for validation in vscode
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
#
# llama-swap YAML configuration example
# -------------------------------------
#
# 💡 Tip - Use an LLM with this file!
# ====================================
# This example configuration is written to be LLM friendly. Try
# copying this file into an LLM and asking it to explain or generate
# sections for you.
# ====================================
# Usage notes:
# - Below are all the available configuration options for llama-swap.
# - Settings noted as "required" must be in your configuration file
# - Settings noted as "optional" can be omitted
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
# - optional, default: 120
# - minimum value is 15 seconds, anything less will be set to this value
healthCheckTimeout: 500
# logLevel: sets the logging value
# - optional, default: info
# - Valid log levels: debug, info, warn, error
logLevel: info
# logTimeFormat: enables and sets the logging timestamp format
# - optional, default (disabled): ""
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
# "stamp", "stampmilli", "stampmicro", and "stampnano".
# - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: ""
# logToStdout: controls what is logged to stdout
# - optional, default: "proxy"
# - valid values:
# - "proxy": logs generated by llama-swap when swapping models,
# handling requests, etc.
# - "upstream": a copy of an upstream processes stdout logs
# - "both": both the proxy and upstream logs interleaved together
# - "none": no logs are ever written to stdout
logToStdout: "proxy"
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
# - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000
# captureBuffer: how many MBs to allocate for storing request/response captures
# - optional, default: 10
# - set to 0 to disable
captureBuffer: 15
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
# - it is automatically incremented for every model that uses it
startPort: 10001
# sendLoadingState: inject loading status updates into the reasoning (thinking)
# field
# - optional, default: false
# - when true, a stream of loading messages will be sent to the client in the
# reasoning field so chat UIs can show that loading is in progress.
# - see #366 for more details
sendLoadingState: true
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
# - optional, default: false
# - when true, model aliases will be output to the API model listing duplicating
# all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false
# globalTTL: the default TTL in seconds before unloading a model
# - optional, default: 0 (never automatically unload)
# - must be >= 0
globalTTL: 0
# macros: a dictionary of string substitutions
# - optional, default: empty dictionary
# - macros are reusable snippets
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
# - useful for reducing common configuration settings
# - macro names are strings and must be less than 64 characters
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
# - macro names must not be a reserved name: PORT or MODEL_ID
# - macro values can be numbers, bools, or strings
# - macros can contain other macros, but they must be defined before they are used
# - environment variables can be referenced with ${env.VAR_NAME} syntax
# - env macros are substituted first, before regular macros
# - if the env var is not set, config loading will fail with an error
macros:
# Example of a multi-line macro
"latest-llama": >
/path/to/llama-server/llama-server-ec9e0301
--port ${PORT}
"default_ctx": 4096
# Example of macro-in-macro usage. macros can contain other macros
# but they must be previously declared.
"default_args": "--ctx-size ${default_ctx}"
# Example of environment variable macros
# - ${env.VAR_NAME} pulls the value from the system environment
# - useful for paths, secrets, or machine-specific configuration
"models_dir": "${env.HOME}/models"
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
# use environment variable macros to keep secrets out of the config
- "${env.API_KEY_1}"
- "${env.API_KEY_2}"
# models: a dictionary of model configurations
# - required
# - each key is the model's ID, used in API requests
# - model settings have default values that are used if they are not defined here
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
# - below are examples of the all the settings a model can have
models:
# keys are the model names used in API requests
"gpt-oss-120b":
# macros: a dictionary of string substitutions specific to this model
# - optional, default: empty dictionary
# - macros defined here override macros defined in the global macros section
# - model level macros follow the same rules as global macros
macros:
"default_ctx": 16384
"temp": 0.7
# cmd: the command to run to start the inference server.
# - required
# - it is just a string, similar to what you would run on the CLI
# - using `|` allows for comments in the command, these will be parsed out
# - macros can be used within cmd
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model path/to/gpt-oss-120B.gguf
--ctx-size ${default_ctx}
--temperature ${temp}
# name: a display name for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
name: "gpt-oss 120B"
# description: a description for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
description: "A thinking model from OpenAI"
# env: define an array of environment variables to inject into cmd's environment
# - optional, default: empty array
# - each value is a single string
# - in the format: ENV_NAME=value
env:
- "CUDA_VISIBLE_DEVICES=0,1,2"
# proxy: the URL where llama-swap routes API requests
# - optional, default: http://localhost:${PORT}
# - if you used ${PORT} in cmd this can be omitted
# - if you use a custom port in cmd this *must* be set
proxy: http://127.0.0.1:8999
# checkEndpoint: URL path to check if the server is ready
# - optional, default: /health
# - endpoint is expected to return an HTTP 200 response
# - all requests wait until the endpoint is ready or fails
# - use "none" to skip endpoint health checking
checkEndpoint: /custom-endpoint
# ttl: automatically unload the model after ttl seconds
# - optional, default: -1 (use global default)
# - ttl values must be a value greater than or equal to 0
# - a ttl of -1 will use the global TTL value as the default
# - a ttl of 0 will mean never unload
# - a value of 0 disables automatic unloading of the model
ttl: 60
# useModelName: override the model name that is sent to upstream server
# - optional, default: ""
# - useful for when the upstream server expects a specific model name that
# is different from the model's ID
useModelName: "openai/gpt-oss-120B"
# filters: a dictionary of filter settings
# - optional, default: empty dictionary
# - same capabilities as peer filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for server side enforcement of sampling parameters
# - the `model` parameter can never be removed
# - can be any JSON key in the request body
# - recommended to stick to sampling parameters
stripParams: "temperature, top_p, top_k"
# setParams: a dictionary of parameters to set/override in requests
# - optional, default: empty dictionary
# - useful for enforcing specific parameter values
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
# - always runs for the model
setParams:
# Example: enforce specific sampling parameters
temperature: 0.7
top_p: 0.9
# setParamsByID: a dictionary of parameters to set based the model ID
# - optional, default: empty dictionary
# - combine with aliases to create variant behaviour without reloading the model
# - parameters are set in the request body JSON
# - run after setParams so it will override any settings
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
# - model aliases will be automatically created for each key
setParamsByID:
"${MODEL_ID}":
chat_template_kwargs:
reasoning_effort: medium
"${MODEL_ID}:high":
chat_template_kwargs:
reasoning_effort: high
"${MODEL_ID}:low":
chat_template_kwargs:
reasoning_effort: low
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
# metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple
# - metadata is only passed through in /v1/models responses
metadata:
# port will remain an integer
port: ${PORT}
# the ${temp} macro will remain a float
temperature: ${temp}
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
a_list:
- 1
- 1.23
- "macros are OK in list and dictionary types: ${MODEL_ID}"
an_obj:
a: "1"
b: 2
# objects can contain complex types with macro substitution
# becomes: c: [0.7, false, "model: llama"]
c: ["${temp}", false, "model: ${MODEL_ID}"]
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
# - optional, default: 0
# - useful for limiting the number of active parallel requests a model can process
# - must be set per model
# - any number greater than 0 will override the internal default value of 10
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
# - recommended to be omitted and the default used
concurrencyLimit: 0
# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
# timeouts: configure proxy connection timeouts for this model
# - optional, defaults shown below
# - useful for models running on slower hardware that need longer timeouts
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
# - idleConn: idle connection timeout in seconds, default: 90 seconds
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
keepalive: 0
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
# - optional, default: false
# - unlisted models do not show up in /v1/models api requests
# - can be requested as normal through all apis
unlisted: true
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
# Docker example:
# container runtimes like Docker and Podman can be used reliably with
# a combination of cmd, cmdStop, and ${MODEL_ID}
"docker-llama":
proxy: "http://127.0.0.1:${PORT}"
cmd: |
docker run --name ${MODEL_ID}
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
ghcr.io/ggml-org/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# cmdStop: command to run to stop the model gracefully
# - optional, default: ""
# - useful for stopping commands managed by another system
# - the upstream's process id is available in the ${PID} macro
#
# When empty, llama-swap has this default behaviour:
# - on POSIX systems: a SIGTERM signal is sent
# - on Windows, calls taskkill to stop the process
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID}
# =============================================================================
# matrix: run concurrent models with a solver-based swap DSL
# =============================================================================
#
# Note:
# A config must use either a matrix or legacy groups, not both. A configuration error
# will occur if both are defined. Configuration examples for legacy Groups can be found:
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
#
# The matrix declares valid combinations of models that can run concurrently.
# When a model is requested, the solver finds the cheapest way to make it
# available by evicting as few (and least costly) running models as possible.
#
# Solver behavior:
# 1. Request arrives for model X
# 2. If X is already running, forward immediately. Done.
# 3. Find all sets containing X
# 4. For each candidate set, compute cost: sum of evict_costs for
# every running model NOT in that set
# 5. Pick lowest cost candidate. Ties broken by definition order.
# 6. Evict what needs to stop. Start X. Forward request.
#
# Subset semantics: a set [a, b, c] means any subset is valid.
# Only the requested model is started — others are not preloaded.
#
# A model not appearing in any set can only run alone.
#
matrix:
# vars: short names for models (alphanumeric, 1-8 chars)
# - required for sets and evict_costs settings
# - each entry is a short name to a real model ID. Do not use an alias
# - used to keep set DSL logic short and easier to read
# - sets and evict_costs only use identifiers defined in vars
vars:
g: gemma-model
q: qwen-model
m: mistral-model
v: voxtral-model
e: reranker-model
L: llama-70B
sd: stable-diffusion
# evict_costs: relative cost of losing a running model (default: 1)
evict_costs:
v: 50 # vllm backend, slow cold start
L: 30 # 70B weights, slow to load
# sets: named sets of concurrent model combinations
# Values are DSL strings with operators:
# & AND (models run together)
# | OR (alternatives)
# () grouping
# +ref inline another set's expression
#
# Expansion examples:
# "L" → [L]
# "a & b" → [a, b]
# "a | b" → [a], [b]
# "(a | b) & c" → [a, c], [b, c]
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
# "+llms & v" → expands llms inline, then applies & v
sets:
# LLM + TTS: switching between g/q/m won't evict v
# expands to: [g,v], [q,v], [m,v]
standard: "(g | q | m) & v"
# LLM + TTS + reranker
# expands to: [g,v,e], [q,v,e]
with_rerank: "(g | q) & v & e"
# LLM + image generation, no TTS
# expands to: [g,sd], [q,sd]
creative: "(g | q) & sd"
# 70B model uses all GPUs, can only run alone
# expands to: [L]
full: "L"
# hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary
# - the only supported hook is on_startup
hooks:
# on_startup: a dictionary of actions to perform on startup
# - optional, default: empty dictionary
# - the only supported action is preload
on_startup:
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
preload:
- "llama"
# peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary
# - peers can be another llama-swap
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
peers:
# keys is the peer'd ID
llama-swap-peer:
# proxy: a valid base URL to proxy requests to
# - required
# - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23
# models: a list of models served by the peer
# - required
models:
- model_a
- model_b
- embeddings/model_c
openrouter:
proxy: https://openrouter.ai/api
# apiKey: a string key to be injected into the request
# - optional, default: ""
# - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
# - can be a string or a macro
apiKey: ${env.OPENROUTER_API_KEY}
models:
- meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507
- deepseek/deepseek-v3.2
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
# timeouts: configure proxy connection timeouts for this peer
# - optional, defaults shown below
# - useful when the peer runs on slower hardware
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
keepalive: 30
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# filters: a dictionary of filter settings for peer requests
# - optional, default: empty dictionary
# - same capabilities as model filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for removing parameters that the peer doesn't support
# - the `model` parameter can never be removed
stripParams: "temperature, top_p"
# setParams: a dictionary of parameters to set/override in requests to this peer
# - optional, default: empty dictionary
# - useful for injecting provider-specific settings like data retention policies
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
setParams:
# Example: enforce zero-data-retention for OpenRouter
provider:
data_collection: "deny"
zdr: true
```
+9
View File
@@ -0,0 +1,9 @@
## Container Security
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
+6 -5
View File
@@ -1,11 +1,12 @@
module github.com/mostlygeek/llama-swap
go 1.23.0
go 1.26.1
require (
github.com/billziss-gh/golib v0.2.0
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0
github.com/klauspost/compress v1.18.5
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
@@ -37,9 +38,9 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+10 -8
View File
@@ -34,6 +34,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
@@ -80,16 +82,16 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
+40 -9
View File
@@ -16,6 +16,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy"
"github.com/mostlygeek/llama-swap/proxy/config"
)
var (
@@ -27,7 +28,9 @@ var (
func main() {
// Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port")
listenStr := flag.String("listen", "", "listen ip/port")
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
keyFile := flag.String("tls-key-file", "", "TLS key file")
showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
@@ -38,13 +41,13 @@ func main() {
os.Exit(0)
}
config, err := proxy.LoadConfig(*configPath)
conf, err := config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
if len(config.Profiles) > 0 {
if len(conf.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}
@@ -54,6 +57,23 @@ func main() {
gin.SetMode(gin.ReleaseMode)
}
// Validate TLS flags.
var useTLS = (*certFile != "" && *keyFile != "")
if (*certFile != "" && *keyFile == "") ||
(*certFile == "" && *keyFile != "") {
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
os.Exit(1)
}
// Set default ports.
if *listenStr == "" {
defaultPort := ":8080"
if useTLS {
defaultPort = ":8443"
}
listenStr = &defaultPort
}
// Setup channels for server management
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1)
@@ -67,7 +87,7 @@ func main() {
// Support for watching config and reloading when it changes
reloadProxyManager := func() {
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
config, err = proxy.LoadConfig(*configPath)
conf, err = config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
return
@@ -75,7 +95,9 @@ func main() {
fmt.Println("Configuration Changed")
currentPM.Shutdown()
srv.Handler = proxy.New(config)
newPM := proxy.New(conf)
newPM.SetVersion(date, commit, version)
srv.Handler = newPM
fmt.Println("Configuration Reloaded")
// wait a few seconds and tell any UI to reload
@@ -85,12 +107,14 @@ func main() {
})
})
} else {
config, err = proxy.LoadConfig(*configPath)
conf, err = config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error, unable to load configuration: %v\n", err)
os.Exit(1)
}
srv.Handler = proxy.New(config)
newPM := proxy.New(conf)
newPM.SetVersion(date, commit, version)
srv.Handler = newPM
}
}
@@ -166,9 +190,16 @@ func main() {
}()
// Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
var err error
if useTLS {
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
err = srv.ListenAndServeTLS(*certFile, *keyFile)
} else {
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
err = srv.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
log.Fatalf("Fatal server error: %v\n", err)
}
}()
Binary file not shown.

Before

Width:  |  Height:  |  Size: 51 KiB

-460
View File
@@ -1,460 +0,0 @@
package proxy
import (
"fmt"
"io"
"os"
"regexp"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3"
)
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
}
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelConfig(defaults)
return nil
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
type ModelFilters struct {
StripParams string `yaml:"strip_params"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelFilters(defaults)
return nil
}
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" {
continue
}
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
}
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type HooksConfig struct {
OnStartup HookOnStartup `yaml:"on_startup"`
}
type HookOnStartup struct {
Preload []string `yaml:"preload"`
}
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros map[string]string `yaml:"macros"`
// map aliases to actual model IDs
aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
// hooks, see: #209
Hooks HooksConfig `yaml:"hooks"`
}
func (c *Config) RealModelName(search string) (string, bool) {
if _, found := c.Models[search]; found {
return search, true
} else if name, found := c.aliases[search]; found {
return name, found
} else {
return "", false
}
}
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
if realName, found := c.RealModelName(modelName); !found {
return ModelConfig{}, "", false
} else {
return c.Models[realName], realName, true
}
}
func LoadConfig(path string) (Config, error) {
file, err := os.Open(path)
if err != nil {
return Config{}, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
}
// default configuration values
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
MetricsMaxInMemory: 1000,
}
err = yaml.Unmarshal(data, &config)
if err != nil {
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
// set a minimum of 15 seconds
config.HealthCheckTimeout = 15
}
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName
}
}
/* check macro constraint rules:
- name must fit the regex ^[a-zA-Z0-9_-]+$
- names must be less than 64 characters (no reason, just cause)
- name can not be any reserved macros: PORT, MODEL_ID
- macro values must be less than 1024 characters
*/
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
for macroName, macroValue := range config.Macros {
if len(macroName) >= 64 {
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
}
if !macroNameRegex.MatchString(macroName) {
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
}
if len(macroValue) >= 1024 {
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
}
switch macroName {
case "PORT":
case "MODEL_ID":
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
}
}
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
// Strip comments from command fields before macro expansion
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range config.Macros {
macroSlug := fmt.Sprintf("${%s}", macroName)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
}
// enforce ${PORT} used in both cmd and proxy
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
nextPortStr := strconv.Itoa(nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
nextPort++
}
if strings.Contains(modelConfig.Cmd, "${MODEL_ID}") || strings.Contains(modelConfig.CmdStop, "${MODEL_ID}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${MODEL_ID}", modelId)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${MODEL_ID}", modelId)
}
// make sure there are no unknown macros that have not been replaced
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
}
for fieldName, fieldValue := range fieldMap {
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
continue // this is ok, has to be replaced by process later
}
if _, exists := config.Macros[macroName]; !exists {
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
}
config.Models[modelId] = modelConfig
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
// clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string
for _, modelID := range config.Hooks.OnStartup.Preload {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
if real, found := config.RealModelName(modelID); found {
toPreload = append(toPreload, real)
}
}
config.Hooks.OnStartup.Preload = toPreload
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
}
func SanitizeCommand(cmdStr string) ([]string, error) {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments
var args []string
if runtime.GOOS == "windows" {
args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
}
// Ensure the command is not empty
if len(args) == 0 {
return nil, fmt.Errorf("empty command")
}
return args, nil
}
func StripComments(cmdStr string) string {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
cleanedLines = append(cleanedLines, line)
}
return strings.Join(cleanedLines, "\n")
}
+820
View File
@@ -0,0 +1,820 @@
package config
import (
"fmt"
"io"
"net/url"
"os"
"regexp"
"runtime"
"sort"
"strings"
"github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3"
)
const DEFAULT_GROUP_ID = "(default)"
const (
LogToStdoutProxy = "proxy"
LogToStdoutUpstream = "upstream"
LogToStdoutBoth = "both"
LogToStdoutNone = "none"
)
type MacroEntry struct {
Name string
Value any
}
type MacroList []MacroEntry
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.MappingNode {
return fmt.Errorf("macros must be a mapping")
}
// yaml.Node.Content for a mapping contains alternating key/value nodes
entries := make([]MacroEntry, 0, len(value.Content)/2)
for i := 0; i < len(value.Content); i += 2 {
keyNode := value.Content[i]
valueNode := value.Content[i+1]
var name string
if err := keyNode.Decode(&name); err != nil {
return fmt.Errorf("failed to decode macro name: %w", err)
}
var val any
if err := valueNode.Decode(&val); err != nil {
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
}
entries = append(entries, MacroEntry{Name: name, Value: val})
}
*ml = entries
return nil
}
// Get retrieves a macro value by name
func (ml MacroList) Get(name string) (any, bool) {
for _, entry := range ml {
if entry.Name == name {
return entry.Value, true
}
}
return nil, false
}
// ToMap converts MacroList to a map (for backward compatibility if needed)
func (ml MacroList) ToMap() map[string]any {
result := make(map[string]any, len(ml))
for _, entry := range ml {
result[entry.Name] = entry.Value
}
return result
}
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
Persistent bool `yaml:"persistent"`
Members []string `yaml:"members"`
}
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
)
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
defaults := rawGroupConfig{
Swap: true,
Exclusive: true,
Persistent: false,
Members: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
*c = GroupConfig(defaults)
return nil
}
type HooksConfig struct {
OnStartup HookOnStartup `yaml:"on_startup"`
}
type HookOnStartup struct {
Preload []string `yaml:"preload"`
}
type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
LogTimeFormat string `yaml:"logTimeFormat"`
LogToStdout string `yaml:"logToStdout"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
CaptureBuffer int `yaml:"captureBuffer"`
GlobalTTL int `yaml:"globalTTL"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// swap matrix: solver-based alternative to groups
Matrix *MatrixConfig `yaml:"matrix"`
// populated during validation when matrix is configured
ExpandedSets []ExpandedSet `yaml:"-"`
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros MacroList `yaml:"macros"`
// map aliases to actual model IDs
aliases map[string]string
// automatic port assignments
StartPort int `yaml:"startPort"`
// hooks, see: #209
Hooks HooksConfig `yaml:"hooks"`
// send loading state in reasoning
SendLoadingState bool `yaml:"sendLoadingState"`
// present aliases to /v1/models OpenAI API listing
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
// support API keys, see issue #433, #50, #251
RequiredAPIKeys []string `yaml:"apiKeys"`
// support remote peers, see issue #433, #296
Peers PeerDictionaryConfig `yaml:"peers"`
}
func (c *Config) RealModelName(search string) (string, bool) {
if _, found := c.Models[search]; found {
return search, true
} else if name, found := c.aliases[search]; found {
return name, found
} else {
return "", false
}
}
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
if realName, found := c.RealModelName(modelName); !found {
return ModelConfig{}, "", false
} else {
return c.Models[realName], realName, true
}
}
func LoadConfig(path string) (Config, error) {
file, err := os.Open(path)
if err != nil {
return Config{}, err
}
defer file.Close()
return LoadConfigFromReader(file)
}
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
}
yamlStr := string(data)
// Phase 1: Substitute all ${env.VAR} macros at string level
// This is safe because env values are simple strings without YAML formatting
yamlStr, err = substituteEnvMacros(yamlStr)
if err != nil {
return Config{}, err
}
// Unmarshal into full Config with defaults
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
GlobalTTL: 0,
}
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
config.HealthCheckTimeout = 15
}
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
if config.GlobalTTL < 0 {
return Config{}, fmt.Errorf("globalTTL must be >= 0")
}
switch config.LogToStdout {
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
default:
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName
}
}
// Validate global macros
for _, macro := range config.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, err
}
}
// Get and sort all model IDs for consistent port assignment
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds)
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
// Strip comments from command fields
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// set model TTL to globalTTL it is the default value
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
modelConfig.UnloadAfter = config.GlobalTTL
}
if modelConfig.UnloadAfter < 0 {
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
}
// Validate model macros
for _, macro := range modelConfig.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
}
}
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
mergedMacros = append(mergedMacros, config.Macros...)
// Add model macros (override globals with same name)
for _, entry := range modelConfig.Macros {
found := false
for i, existing := range mergedMacros {
if existing.Name == entry.Name {
mergedMacros[i] = entry
found = true
break
}
}
if !found {
mergedMacros = append(mergedMacros, entry)
}
}
// Substitute remaining macros in model fields (LIFO order)
for i := len(mergedMacros) - 1; i >= 0; i-- {
entry := mergedMacros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
// Substitute macros in SetParamsByID keys and values
if len(modelConfig.Filters.SetParamsByID) > 0 {
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
for key, paramMap := range modelConfig.Filters.SetParamsByID {
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
}
newParamMap, ok := newValAny.(map[string]any)
if !ok {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
}
newSetParamsByID[newKey] = newParamMap
}
modelConfig.Filters.SetParamsByID = newSetParamsByID
}
// Substitute in metadata (type-preserving)
if len(modelConfig.Metadata) > 0 {
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
}
// Handle PORT macro - only allocate if cmd uses it
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
if cmdHasPort || proxyHasPort {
if !cmdHasPort && proxyHasPort {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
macroSlug := "${PORT}"
macroStr := fmt.Sprintf("%v", nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
if len(modelConfig.Metadata) > 0 {
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
nextPort++
}
// Validate no unknown macros remain
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
"filters.stripParams": modelConfig.Filters.StripParams,
"name": modelConfig.Name,
"description": modelConfig.Description,
}
for fieldName, fieldValue := range fieldMap {
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
continue // replaced at runtime
}
if macroName == "PORT" || macroName == "MODEL_ID" {
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
}
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
if len(modelConfig.Metadata) > 0 {
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
return Config{}, err
}
}
// Validate SetParamsByID keys and values
for key, paramMap := range modelConfig.Filters.SetParamsByID {
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
}
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
return Config{}, err
}
}
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
for key := range modelConfig.Filters.SetParamsByID {
if key == modelId {
continue
}
if _, exists := config.Models[key]; exists {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
}
if existingModel, exists := config.aliases[key]; exists {
if existingModel != modelId {
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
}
continue // already registered as explicit alias for this model
}
config.aliases[key] = modelId
modelConfig.Aliases = append(modelConfig.Aliases, key)
}
if _, err := url.Parse(modelConfig.Proxy); err != nil {
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
}
if modelConfig.SendLoadingState == nil {
v := config.SendLoadingState
modelConfig.SendLoadingState = &v
}
config.Models[modelId] = modelConfig
}
// groups XOR matrix
if config.Matrix != nil && len(config.Groups) > 0 {
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
}
if config.Matrix != nil {
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
if err != nil {
return Config{}, fmt.Errorf("matrix: %w", err)
}
config.ExpandedSets = expandedSets
} else {
config = AddDefaultGroupToConfig(config)
// Validate group members
memberUsage := make(map[string]string)
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
}
// Clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string
for _, modelID := range config.Hooks.OnStartup.Preload {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
if real, found := config.RealModelName(modelID); found {
toPreload = append(toPreload, real)
}
}
config.Hooks.OnStartup.Preload = toPreload
}
// Validate API keys (env macros already substituted at string level)
for i, apikey := range config.RequiredAPIKeys {
if apikey == "" {
return Config{}, fmt.Errorf("empty api key found in apiKeys")
}
if strings.Contains(apikey, " ") {
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
}
config.RequiredAPIKeys[i] = apikey
}
// Process peers with global macro substitution
for peerName, peerConfig := range config.Peers {
// Substitute global macros (LIFO order)
for i := len(config.Macros) - 1; i >= 0; i-- {
entry := config.Macros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
// Substitute in setParams (type-preserving)
if len(peerConfig.Filters.SetParams) > 0 {
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
}
peerConfig.Filters.SetParams = result.(map[string]any)
}
}
// Validate no unknown macros remain
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
}
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
}
if len(peerConfig.Filters.SetParams) > 0 {
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
return Config{}, err
}
}
config.Peers[peerName] = peerConfig
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config {
if config.Groups == nil {
config.Groups = make(map[string]GroupConfig)
}
defaultGroup := GroupConfig{
Swap: true,
Exclusive: true,
Members: []string{},
}
// if groups is empty, create a default group and put
// all models into it
if len(config.Groups) == 0 {
for modelName := range config.Models {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName := range config.Models {
foundModel := false
found:
// search for the model in existing groups
for _, groupConfig := range config.Groups {
for _, member := range groupConfig.Members {
if member == modelName {
foundModel = true
break found
}
}
}
if !foundModel {
defaultGroup.Members = append(defaultGroup.Members, modelName)
}
}
}
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
return config
}
func SanitizeCommand(cmdStr string) ([]string, error) {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments
var args []string
if runtime.GOOS == "windows" {
args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
}
// Ensure the command is not empty
if len(args) == 0 {
return nil, fmt.Errorf("empty command")
}
return args, nil
}
func StripComments(cmdStr string) string {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
cleanedLines = append(cleanedLines, line)
}
return strings.Join(cleanedLines, "\n")
}
// validateMacro validates macro name and value constraints
func validateMacro(name string, value any) error {
if len(name) >= 64 {
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
}
if !macroNameRegex.MatchString(name) {
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
}
// Validate that value is a scalar type
switch v := value.(type) {
case string:
if len(v) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
// Check for self-reference
macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(v, macroSlug) {
return fmt.Errorf("macro '%s' contains self-reference", name)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
// These types are allowed
default:
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
}
switch name {
case "PORT", "MODEL_ID":
return fmt.Errorf("macro name '%s' is reserved", name)
}
return nil
}
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
func validateNestedForUnknownMacros(value any, context string) error {
switch v := value.(type) {
case string:
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
for _, match := range matches {
macroName := match[1]
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
}
// Check for unsubstituted env macros
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
for _, match := range envMatches {
varName := match[1]
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
}
return nil
case map[string]any:
for _, val := range v {
if err := validateNestedForUnknownMacros(val, context); err != nil {
return err
}
}
return nil
case []any:
for _, val := range v {
if err := validateNestedForUnknownMacros(val, context); err != nil {
return err
}
}
return nil
default:
// Scalar types don't contain macros
return nil
}
}
// substituteMacroInValue recursively substitutes a single macro in a value structure
// This is called once per macro, allowing LIFO substitution order
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
macroSlug := fmt.Sprintf("${%s}", macroName)
macroStr := fmt.Sprintf("%v", macroValue)
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if v == macroSlug {
return macroValue, nil
}
// Handle string interpolation
if strings.Contains(v, macroSlug) {
return strings.ReplaceAll(v, macroSlug, macroStr), nil
}
return v, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newMap[key] = newVal
}
return newMap, nil
case []any:
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newSlice[i] = newVal
}
return newSlice, nil
default:
// Return scalar types as-is
return value, nil
}
}
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
// Returns error if any referenced env var is not set or contains invalid characters.
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
// (which strips comments) and only checking the comment-free version for macros.
func substituteEnvMacros(s string) (string, error) {
// Unmarshal and remarshal to strip YAML comments
var raw any
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
// If YAML is invalid, fall back to scanning the original string
// so the user gets the env var error rather than a confusing YAML parse error
return substituteEnvMacrosInString(s, s)
}
clean, err := yaml.Marshal(raw)
if err != nil {
return substituteEnvMacrosInString(s, s)
}
return substituteEnvMacrosInString(s, string(clean))
}
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
// them in target. This separation allows scanning comment-free YAML while
// substituting in the original string.
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
result := target
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
for _, match := range matches {
fullMatch := match[0] // ${env.VAR_NAME}
varName := match[1] // VAR_NAME
value, exists := os.LookupEnv(varName)
if !exists {
return "", fmt.Errorf("environment variable '%s' is not set", varName)
}
// Sanitize the value for safe YAML substitution
value, err := sanitizeEnvValueForYAML(value, varName)
if err != nil {
return "", err
}
result = strings.ReplaceAll(result, fullMatch, value)
}
return result, nil
}
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
// for compatibility with double-quoted YAML strings.
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
// Reject values that would break YAML structure regardless of quoting context
if strings.ContainsAny(value, "\n\r\x00") {
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
}
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
// In double-quoted contexts, they are interpreted correctly.
value = strings.ReplaceAll(value, `\`, `\\`)
value = strings.ReplaceAll(value, `"`, `\"`)
return value, nil
}
@@ -1,6 +1,6 @@
//go:build !windows
package proxy
package config
import (
"os"
@@ -58,6 +58,7 @@ models:
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)
assert.Equal(t, "", config.LogTimeFormat)
// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
@@ -160,51 +161,74 @@ groups:
t.Fatalf("Failed to load config: %v", err)
}
modelLoadingState := false
defaultTimeout := TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
}
expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
"svr-path": "path/to/server",
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
StartPort: 5800,
Macros: MacroList{
{"svr-path", "path/to/server"},
},
Hooks: HooksConfig{
OnStartup: HookOnStartup{
Preload: []string{"model1", "model2"},
},
},
SendLoadingState: false,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model2": {
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/server --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
},
HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
File diff suppressed because it is too large Load Diff
@@ -1,6 +1,6 @@
//go:build windows
package proxy
package config
import (
"os"
@@ -55,6 +55,7 @@ models:
assert.Equal(t, 120, config.HealthCheckTimeout)
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "info", config.LogLevel)
assert.Equal(t, "", config.LogTimeFormat)
// Test default group exists
defaultGroup, exists := config.Groups["(default)"]
@@ -152,48 +153,71 @@ groups:
t.Fatalf("Failed to load config: %v", err)
}
modelLoadingState := false
defaultTimeout := TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
}
expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
"svr-path": "path/to/server",
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
StartPort: 5800,
Macros: MacroList{
{"svr-path", "path/to/server"},
},
SendLoadingState: false,
Models: map[string]ModelConfig{
"model1": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model2": {
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/server --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8081",
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
Cmd: "path/to/cmd --arg1 one",
CmdStop: "taskkill /f /t /pid ${PID}",
Proxy: "http://localhost:8082",
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
},
HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
+114
View File
@@ -0,0 +1,114 @@
package config
import (
"slices"
"sort"
"strings"
)
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
// These are protected to prevent breaking the proxy's ability to route requests correctly
var ProtectedParams = []string{"model"}
// Filters contains filter settings for modifying request parameters
// Used by both models and peers
type Filters struct {
// StripParams is a comma-separated list of parameters to remove from requests
// The "model" parameter can never be removed
StripParams string `yaml:"stripParams"`
// SetParams is a dictionary of parameters to set/override in requests
// Protected params (like "model") cannot be set
SetParams map[string]any `yaml:"setParams"`
// SetParamsByID maps requested model IDs to parameters to set/override in requests.
// Useful with aliases: a single loaded model can behave differently depending on
// which alias the client used. Applied after SetParams, so it can override those values.
// Protected params (like "model") cannot be set.
SetParamsByID map[string]map[string]any `yaml:"setParamsByID"`
}
// SanitizedStripParams returns a sorted list of parameters to strip,
// with duplicates, empty strings, and protected params removed
func (f Filters) SanitizedStripParams() []string {
if f.StripParams == "" {
return nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)
for _, param := range params {
trimmed := strings.TrimSpace(param)
// Skip protected params, empty strings, and duplicates
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}
if len(cleaned) == 0 {
return nil
}
slices.Sort(cleaned)
return cleaned
}
// SanitizedSetParamsByID returns the params to set for the given requestedModelID,
// with protected params removed and keys sorted for consistent iteration order.
// Returns nil if the ID has no entry or all its params are protected.
func (f Filters) SanitizedSetParamsByID(requestedModelID string) (map[string]any, []string) {
if len(f.SetParamsByID) == 0 {
return nil, nil
}
params, found := f.SetParamsByID[requestedModelID]
if !found || len(params) == 0 {
return nil, nil
}
result := make(map[string]any, len(params))
keys := make([]string, 0, len(params))
for key, value := range params {
if slices.Contains(ProtectedParams, key) {
continue
}
result[key] = value
keys = append(keys, key)
}
sort.Strings(keys)
if len(result) == 0 {
return nil, nil
}
return result, keys
}
// SanitizedSetParams returns a copy of SetParams with protected params removed
// and keys sorted for consistent iteration order
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
if len(f.SetParams) == 0 {
return nil, nil
}
result := make(map[string]any, len(f.SetParams))
keys := make([]string, 0, len(f.SetParams))
for key, value := range f.SetParams {
// Skip protected params
if slices.Contains(ProtectedParams, key) {
continue
}
result[key] = value
keys = append(keys, key)
}
// Sort keys for consistent ordering
sort.Strings(keys)
if len(result) == 0 {
return nil, nil
}
return result, keys
}
+285
View File
@@ -0,0 +1,285 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFilters_SanitizedStripParams(t *testing.T) {
tests := []struct {
name string
stripParams string
want []string
}{
{
name: "empty string",
stripParams: "",
want: nil,
},
{
name: "single param",
stripParams: "temperature",
want: []string{"temperature"},
},
{
name: "multiple params",
stripParams: "temperature, top_p, top_k",
want: []string{"temperature", "top_k", "top_p"}, // sorted
},
{
name: "model param filtered",
stripParams: "model, temperature, top_p",
want: []string{"temperature", "top_p"},
},
{
name: "only model param",
stripParams: "model",
want: nil,
},
{
name: "duplicates removed",
stripParams: "temperature, top_p, temperature",
want: []string{"temperature", "top_p"},
},
{
name: "extra whitespace",
stripParams: " temperature , top_p ",
want: []string{"temperature", "top_p"},
},
{
name: "empty values filtered",
stripParams: "temperature,,top_p,",
want: []string{"temperature", "top_p"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{StripParams: tt.stripParams}
got := f.SanitizedStripParams()
assert.Equal(t, tt.want, got)
})
}
}
func TestFilters_SanitizedSetParams(t *testing.T) {
tests := []struct {
name string
setParams map[string]any
wantParams map[string]any
wantKeys []string
}{
{
name: "empty setParams",
setParams: nil,
wantParams: nil,
wantKeys: nil,
},
{
name: "empty map",
setParams: map[string]any{},
wantParams: nil,
wantKeys: nil,
},
{
name: "normal params",
setParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantKeys: []string{"temperature", "top_p"},
},
{
name: "protected model param filtered",
setParams: map[string]any{
"model": "should-be-filtered",
"temperature": 0.7,
},
wantParams: map[string]any{
"temperature": 0.7,
},
wantKeys: []string{"temperature"},
},
{
name: "only protected param",
setParams: map[string]any{
"model": "should-be-filtered",
},
wantParams: nil,
wantKeys: nil,
},
{
name: "complex nested values",
setParams: map[string]any{
"provider": map[string]any{
"data_collection": "deny",
"allow_fallbacks": false,
},
"transforms": []string{"middle-out"},
},
wantParams: map[string]any{
"provider": map[string]any{
"data_collection": "deny",
"allow_fallbacks": false,
},
"transforms": []string{"middle-out"},
},
wantKeys: []string{"provider", "transforms"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{SetParams: tt.setParams}
gotParams, gotKeys := f.SanitizedSetParams()
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
for i, key := range gotKeys {
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
}
if tt.wantParams == nil {
assert.Nil(t, gotParams, "expected nil params")
return
}
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
for key, wantValue := range tt.wantParams {
gotValue, exists := gotParams[key]
assert.True(t, exists, "missing key: %s", key)
// Simple comparison for basic types
switch v := wantValue.(type) {
case string, int, float64, bool:
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
}
}
})
}
}
func TestFilters_SanitizedSetParamsByID(t *testing.T) {
tests := []struct {
name string
setParamsByID map[string]map[string]any
requestedModelID string
wantParams map[string]any
wantKeys []string
}{
{
name: "empty SetParamsByID returns nil",
setParamsByID: nil,
requestedModelID: "model1",
wantParams: nil,
wantKeys: nil,
},
{
name: "empty map returns nil",
setParamsByID: map[string]map[string]any{},
requestedModelID: "model1",
wantParams: nil,
wantKeys: nil,
},
{
name: "non-matching model ID returns nil",
setParamsByID: map[string]map[string]any{
"model2": {"temperature": 0.9},
},
requestedModelID: "model1",
wantParams: nil,
wantKeys: nil,
},
{
name: "matching model ID returns correct params",
setParamsByID: map[string]map[string]any{
"model1": {"temperature": 0.7, "top_p": 0.9},
"model2": {"temperature": 0.5},
},
requestedModelID: "model1",
wantParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantKeys: []string{"temperature", "top_p"},
},
{
name: "protected param model is filtered out",
setParamsByID: map[string]map[string]any{
"model1": {
"model": "should-be-filtered",
"temperature": 0.7,
},
},
requestedModelID: "model1",
wantParams: map[string]any{
"temperature": 0.7,
},
wantKeys: []string{"temperature"},
},
{
name: "only protected param returns nil",
setParamsByID: map[string]map[string]any{
"model1": {
"model": "should-be-filtered",
},
},
requestedModelID: "model1",
wantParams: nil,
wantKeys: nil,
},
{
name: "keys are sorted",
setParamsByID: map[string]map[string]any{
"model1": {
"z_param": "z",
"a_param": "a",
"m_param": "m",
},
},
requestedModelID: "model1",
wantParams: map[string]any{
"z_param": "z",
"a_param": "a",
"m_param": "m",
},
wantKeys: []string{"a_param", "m_param", "z_param"},
},
{
name: "alias style key lookup",
setParamsByID: map[string]map[string]any{
"model1:high": {"reasoning_effort": "high"},
"model1:low": {"reasoning_effort": "low"},
},
requestedModelID: "model1:high",
wantParams: map[string]any{
"reasoning_effort": "high",
},
wantKeys: []string{"reasoning_effort"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{SetParamsByID: tt.setParamsByID}
gotParams, gotKeys := f.SanitizedSetParamsByID(tt.requestedModelID)
if tt.wantParams == nil {
assert.Nil(t, gotParams)
assert.Nil(t, gotKeys)
return
}
assert.Equal(t, tt.wantKeys, gotKeys)
assert.Equal(t, tt.wantParams, gotParams)
})
}
}
func TestProtectedParams(t *testing.T) {
// Verify that "model" is protected
assert.Contains(t, ProtectedParams, "model")
}
+179
View File
@@ -0,0 +1,179 @@
package config
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// Test macro-in-macro basic substitution
func TestConfig_MacroInMacroBasic(t *testing.T) {
content := `
startPort: 10000
macros:
"A": "value-A"
"B": "prefix-${A}-suffix"
models:
test:
cmd: echo ${B}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
}
// Test LIFO substitution order with 3+ macro levels
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
content := `
startPort: 10000
macros:
"base": "/models"
"path": "${base}/llama"
"full": "${path}/model.gguf"
models:
test:
cmd: load ${full}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
}
// Test MODEL_ID in global macro used by model
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
content := `
startPort: 10000
macros:
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
models:
my-model:
cmd: ${podman-llama} -m model.gguf
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
}
// Test model macro overrides global macro in substitution
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
content := `
startPort: 10000
macros:
"tag": "global"
"msg": "value-${tag}"
models:
test:
macros:
"tag": "model-level"
cmd: echo ${msg}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
}
// Test self-reference detection error
func TestConfig_SelfReferenceDetection(t *testing.T) {
content := `
startPort: 10000
macros:
"recursive": "value-${recursive}"
models:
test:
cmd: echo ${recursive}
proxy: http://localhost:8080
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "recursive")
assert.Contains(t, err.Error(), "self-reference")
}
// Test macro substitution in name and description fields
func TestConfig_MacroInNameAndDescription(t *testing.T) {
content := `
startPort: 10000
macros:
"VARIANT": "Q4_K_M"
"FAMILY": "llama"
models:
my-model:
cmd: echo ok
proxy: http://localhost:8080
name: "${FAMILY} ${VARIANT}"
description: "A ${FAMILY} model in ${VARIANT} format"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "llama Q4_K_M", config.Models["my-model"].Name)
assert.Equal(t, "A llama model in Q4_K_M format", config.Models["my-model"].Description)
}
// Test MODEL_ID macro in name and description fields
func TestConfig_ModelIDInNameAndDescription(t *testing.T) {
content := `
startPort: 10000
models:
llama-3b:
cmd: echo ok
proxy: http://localhost:8080
name: "Model: ${MODEL_ID}"
description: "Running ${MODEL_ID}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "Model: llama-3b", config.Models["llama-3b"].Name)
assert.Equal(t, "Running llama-3b", config.Models["llama-3b"].Description)
}
// Test unknown macro in name or description returns an error
func TestConfig_UnknownMacroInNameDescription(t *testing.T) {
content := `
startPort: 10000
models:
test:
cmd: echo ok
proxy: http://localhost:8080
name: "Model ${UNDEFINED}"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "UNDEFINED")
}
// Test undefined macro reference error
func TestConfig_UndefinedMacroReference(t *testing.T) {
content := `
startPort: 10000
macros:
"A": "value-${UNDEFINED}"
models:
test:
cmd: echo ${A}
proxy: http://localhost:8080
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "UNDEFINED")
}
+226
View File
@@ -0,0 +1,226 @@
package config
import (
"fmt"
"regexp"
"sort"
"gopkg.in/yaml.v3"
)
var varKeyPattern = regexp.MustCompile(`^[a-zA-Z0-9]{1,8}$`)
// MatrixConfig represents the swap matrix configuration block.
type MatrixConfig struct {
Var map[string]string `yaml:"vars"`
EvictCosts map[string]int `yaml:"evict_costs"`
Sets OrderedSets `yaml:"sets"`
}
// SetEntry is a single named set with its DSL expression.
type SetEntry struct {
Name string
DSL string
}
// OrderedSets preserves YAML definition order of sets (used for tie-breaking).
type OrderedSets []SetEntry
func (os *OrderedSets) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.MappingNode {
return fmt.Errorf("sets must be a mapping")
}
entries := make([]SetEntry, 0, len(value.Content)/2)
for i := 0; i < len(value.Content); i += 2 {
keyNode := value.Content[i]
valueNode := value.Content[i+1]
var name string
if err := keyNode.Decode(&name); err != nil {
return fmt.Errorf("failed to decode set name: %w", err)
}
var dsl string
if err := valueNode.Decode(&dsl); err != nil {
return fmt.Errorf("failed to decode DSL for set %q: %w", name, err)
}
entries = append(entries, SetEntry{Name: name, DSL: dsl})
}
*os = entries
return nil
}
// ExpandedSet is one valid combination of concurrent models (real model names).
type ExpandedSet struct {
SetName string
DSL string
Models []string // real model names, sorted
}
// ValidateMatrix validates the matrix config and returns all expanded sets.
func ValidateMatrix(matrix MatrixConfig, models map[string]ModelConfig) ([]ExpandedSet, error) {
if len(matrix.Sets) == 0 {
return nil, fmt.Errorf("matrix must define at least one set")
}
if len(matrix.Var) == 0 {
return nil, fmt.Errorf("matrix must define at least one var")
}
// Validate var entries
if matrix.Var != nil {
for id, modelName := range matrix.Var {
if !varKeyPattern.MatchString(id) {
return nil, fmt.Errorf("var key %q must be alphanumeric and 1-8 characters", id)
}
if _, exists := models[modelName]; !exists {
return nil, fmt.Errorf("var key %q references unknown model %q", id, modelName)
}
}
}
// Validate evict_costs
if matrix.EvictCosts != nil {
for key, cost := range matrix.EvictCosts {
if cost <= 0 {
return nil, fmt.Errorf("evict_cost for %q must be a positive integer, got %d", key, cost)
}
if _, ok := matrix.Var[key]; !ok {
return nil, fmt.Errorf("evict_costs: unknown var ID %q", key)
}
}
}
// Build dependency graph for +ref topological sort
setNames := make(map[string]bool)
for _, entry := range matrix.Sets {
setNames[entry.Name] = true
}
deps := make(map[string][]string) // setName -> set names it depends on
for _, entry := range matrix.Sets {
refs, err := extractRefs(entry.DSL)
if err != nil {
return nil, fmt.Errorf("set %q: %w", entry.Name, err)
}
for _, ref := range refs {
if !setNames[ref] {
return nil, fmt.Errorf("set %q references undefined set %q", entry.Name, ref)
}
}
deps[entry.Name] = refs
}
// Topological sort with cycle detection
order, err := topologicalSort(matrix.Sets, deps)
if err != nil {
return nil, err
}
// Expand sets in topological order
resolvedRefs := make(map[string][][]string) // set name -> expanded alias-level combos
var allExpanded []ExpandedSet
totalCombinations := 0
// Build ordered map for efficient lookup
setDSL := make(map[string]string)
for _, entry := range matrix.Sets {
setDSL[entry.Name] = entry.DSL
}
for _, name := range order {
dsl := setDSL[name]
combos, err := ParseAndExpandDSL(dsl, resolvedRefs)
if err != nil {
return nil, fmt.Errorf("set %q: %w", name, err)
}
resolvedRefs[name] = combos
// Resolve var IDs to real model names
for _, combo := range combos {
resolved := make([]string, len(combo))
for i, ident := range combo {
realName, ok := matrix.Var[ident]
if !ok {
return nil, fmt.Errorf("set %q: unknown var ID %q", name, ident)
}
resolved[i] = realName
}
sort.Strings(resolved)
allExpanded = append(allExpanded, ExpandedSet{
SetName: name,
DSL: dsl,
Models: resolved,
})
}
totalCombinations += len(combos)
if totalCombinations > maxDSLExpansions {
return nil, fmt.Errorf("total expanded combinations (%d) exceed limit of %d", totalCombinations, maxDSLExpansions)
}
}
return allExpanded, nil
}
// topologicalSort returns set names in dependency order.
// Returns an error if a cycle is detected.
func topologicalSort(sets OrderedSets, deps map[string][]string) ([]string, error) {
// States: 0 = unvisited, 1 = visiting, 2 = visited
state := make(map[string]int)
var order []string
var visit func(name string) error
visit = func(name string) error {
switch state[name] {
case 1:
return fmt.Errorf("circular reference detected involving set %q", name)
case 2:
return nil
}
state[name] = 1
for _, dep := range deps[name] {
if err := visit(dep); err != nil {
return err
}
}
state[name] = 2
order = append(order, name)
return nil
}
// Visit in definition order for deterministic output
for _, entry := range sets {
if state[entry.Name] == 0 {
if err := visit(entry.Name); err != nil {
return nil, err
}
}
}
return order, nil
}
// ResolvedEvictCosts returns a map of real model name -> evict cost,
// resolving var IDs. Models not listed default to 1.
func (m *MatrixConfig) ResolvedEvictCosts() map[string]int {
costs := make(map[string]int)
if m.EvictCosts == nil {
return costs
}
for key, cost := range m.EvictCosts {
// Resolve var ID if present
if realName, ok := m.Var[key]; ok {
costs[realName] = cost
} else {
costs[key] = cost
}
}
return costs
}
+376
View File
@@ -0,0 +1,376 @@
package config
import (
"fmt"
"sort"
"strings"
"unicode"
)
const maxDSLExpansions = 1000
// Token types for the DSL lexer
type tokenType int
const (
tokIdent tokenType = iota // model alias or name
tokAnd // &
tokOr // |
tokLParen // (
tokRParen // )
tokRef // +setName
tokEOF
)
type token struct {
typ tokenType
val string
}
// tokenize splits a DSL string into tokens.
func tokenize(input string) ([]token, error) {
var tokens []token
i := 0
runes := []rune(input)
for i < len(runes) {
ch := runes[i]
// skip whitespace
if unicode.IsSpace(ch) {
i++
continue
}
switch ch {
case '&':
tokens = append(tokens, token{tokAnd, "&"})
i++
case '|':
tokens = append(tokens, token{tokOr, "|"})
i++
case '(':
tokens = append(tokens, token{tokLParen, "("})
i++
case ')':
tokens = append(tokens, token{tokRParen, ")"})
i++
case '+':
// +ref: read the identifier that follows
i++
start := i
for i < len(runes) && isIdentChar(runes[i]) {
i++
}
if i == start {
return nil, fmt.Errorf("expected set name after '+' at position %d", start)
}
tokens = append(tokens, token{tokRef, string(runes[start:i])})
default:
if isIdentChar(ch) {
start := i
for i < len(runes) && isIdentChar(runes[i]) {
i++
}
tokens = append(tokens, token{tokIdent, string(runes[start:i])})
} else {
return nil, fmt.Errorf("unexpected character %q at position %d", ch, i)
}
}
}
tokens = append(tokens, token{tokEOF, ""})
return tokens, nil
}
func isIdentChar(ch rune) bool {
return unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' || ch == '-' || ch == '.'
}
// AST node types
type dslNode interface {
dslNode()
}
type andNode struct {
children []dslNode
}
type orNode struct {
children []dslNode
}
type leafNode struct {
name string
}
type refNode struct {
setName string
}
func (andNode) dslNode() {}
func (orNode) dslNode() {}
func (leafNode) dslNode() {}
func (refNode) dslNode() {}
// parser holds state for recursive-descent parsing.
type parser struct {
tokens []token
pos int
}
func (p *parser) peek() token {
if p.pos < len(p.tokens) {
return p.tokens[p.pos]
}
return token{tokEOF, ""}
}
func (p *parser) next() token {
t := p.peek()
if t.typ != tokEOF {
p.pos++
}
return t
}
func (p *parser) expect(typ tokenType) (token, error) {
t := p.next()
if t.typ != typ {
return t, fmt.Errorf("expected token type %d, got %q", typ, t.val)
}
return t, nil
}
// Grammar:
//
// expr = andExpr
// andExpr = orExpr ('&' orExpr)*
// orExpr = atom ('|' atom)*
// atom = ident | '+' ident | '(' expr ')'
//
// & binds tighter than |, so "a | b & c" means "a | (b & c)"
func parse(tokens []token) (dslNode, error) {
p := &parser{tokens: tokens}
node, err := p.parseExpr()
if err != nil {
return nil, err
}
if p.peek().typ != tokEOF {
return nil, fmt.Errorf("unexpected token %q after expression", p.peek().val)
}
return node, nil
}
func (p *parser) parseExpr() (dslNode, error) {
return p.parseOrExpr()
}
func (p *parser) parseOrExpr() (dslNode, error) {
left, err := p.parseAndExpr()
if err != nil {
return nil, err
}
if p.peek().typ == tokOr {
children := []dslNode{left}
for p.peek().typ == tokOr {
p.next() // consume |
right, err := p.parseAndExpr()
if err != nil {
return nil, err
}
children = append(children, right)
}
return orNode{children: children}, nil
}
return left, nil
}
func (p *parser) parseAndExpr() (dslNode, error) {
left, err := p.parseAtom()
if err != nil {
return nil, err
}
if p.peek().typ == tokAnd {
children := []dslNode{left}
for p.peek().typ == tokAnd {
p.next() // consume &
right, err := p.parseAtom()
if err != nil {
return nil, err
}
children = append(children, right)
}
return andNode{children: children}, nil
}
return left, nil
}
func (p *parser) parseAtom() (dslNode, error) {
t := p.peek()
switch t.typ {
case tokIdent:
p.next()
return leafNode{name: t.val}, nil
case tokRef:
p.next()
return refNode{setName: t.val}, nil
case tokLParen:
p.next() // consume (
node, err := p.parseExpr()
if err != nil {
return nil, err
}
if _, err := p.expect(tokRParen); err != nil {
return nil, fmt.Errorf("missing closing parenthesis")
}
return node, nil
default:
return nil, fmt.Errorf("unexpected token %q", t.val)
}
}
// expand walks the AST and produces all combinations.
// resolvedRefs contains previously expanded sets for +ref resolution.
func expand(node dslNode, resolvedRefs map[string][][]string) ([][]string, error) {
switch n := node.(type) {
case leafNode:
return [][]string{{n.name}}, nil
case refNode:
expanded, ok := resolvedRefs[n.setName]
if !ok {
return nil, fmt.Errorf("unknown set reference +%s", n.setName)
}
// Return a copy
result := make([][]string, len(expanded))
for i, combo := range expanded {
result[i] = make([]string, len(combo))
copy(result[i], combo)
}
return result, nil
case orNode:
// Union of all children's expansions
var result [][]string
for _, child := range n.children {
childResult, err := expand(child, resolvedRefs)
if err != nil {
return nil, err
}
result = append(result, childResult...)
if len(result) > maxDSLExpansions {
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
}
}
return result, nil
case andNode:
// Cartesian product across children
result := [][]string{{}} // start with one empty combo
for _, child := range n.children {
childResult, err := expand(child, resolvedRefs)
if err != nil {
return nil, err
}
result, err = cartesianProduct(result, childResult, maxDSLExpansions)
if err != nil {
return nil, err
}
}
return result, nil
default:
return nil, fmt.Errorf("unknown node type %T", node)
}
}
// cartesianProduct computes the cartesian product of two sets of combinations.
// It returns an error if the product would exceed cap.
func cartesianProduct(left, right [][]string, cap int) ([][]string, error) {
if int64(len(left))*int64(len(right)) > int64(cap) {
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", cap)
}
result := make([][]string, 0, len(left)*len(right))
for _, l := range left {
for _, r := range right {
combo := make([]string, 0, len(l)+len(r))
combo = append(combo, l...)
combo = append(combo, r...)
result = append(result, combo)
}
}
return result, nil
}
// ParseAndExpandDSL tokenizes, parses, and expands a DSL string.
// resolvedRefs contains previously expanded sets for +ref inlining.
func ParseAndExpandDSL(dsl string, resolvedRefs map[string][][]string) ([][]string, error) {
dsl = strings.TrimSpace(dsl)
if dsl == "" {
return nil, fmt.Errorf("empty DSL expression")
}
tokens, err := tokenize(dsl)
if err != nil {
return nil, fmt.Errorf("tokenize: %w", err)
}
tree, err := parse(tokens)
if err != nil {
return nil, fmt.Errorf("parse: %w", err)
}
result, err := expand(tree, resolvedRefs)
if err != nil {
return nil, err
}
// Deduplicate models within each combination and sort for consistency
for i, combo := range result {
result[i] = dedupAndSort(combo)
}
return result, nil
}
// dedupAndSort removes duplicate entries and sorts alphabetically.
func dedupAndSort(items []string) []string {
seen := make(map[string]bool, len(items))
var unique []string
for _, item := range items {
if !seen[item] {
seen[item] = true
unique = append(unique, item)
}
}
sort.Strings(unique)
return unique
}
// extractRefs scans a DSL string for +ref tokens without full parsing.
// Used for building the dependency graph for topological sorting.
func extractRefs(dsl string) ([]string, error) {
tokens, err := tokenize(dsl)
if err != nil {
return nil, err
}
var refs []string
seen := make(map[string]bool)
for _, t := range tokens {
if t.typ == tokRef && !seen[t.val] {
seen[t.val] = true
refs = append(refs, t.val)
}
}
return refs, nil
}
+300
View File
@@ -0,0 +1,300 @@
package config
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDSL_Tokenize(t *testing.T) {
tests := []struct {
name string
input string
expect []token
errMsg string
}{
{
name: "single identifier",
input: "abc",
expect: []token{
{tokIdent, "abc"},
{tokEOF, ""},
},
},
{
name: "identifier with hyphens and dots",
input: "model-name.v2",
expect: []token{
{tokIdent, "model-name.v2"},
{tokEOF, ""},
},
},
{
name: "and expression",
input: "a & b",
expect: []token{
{tokIdent, "a"},
{tokAnd, "&"},
{tokIdent, "b"},
{tokEOF, ""},
},
},
{
name: "or expression",
input: "a | b",
expect: []token{
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokEOF, ""},
},
},
{
name: "parentheses",
input: "(a | b) & c",
expect: []token{
{tokLParen, "("},
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokRParen, ")"},
{tokAnd, "&"},
{tokIdent, "c"},
{tokEOF, ""},
},
},
{
name: "ref token",
input: "+llms & v",
expect: []token{
{tokRef, "llms"},
{tokAnd, "&"},
{tokIdent, "v"},
{tokEOF, ""},
},
},
{
name: "no whitespace",
input: "(a|b)&c",
expect: []token{
{tokLParen, "("},
{tokIdent, "a"},
{tokOr, "|"},
{tokIdent, "b"},
{tokRParen, ")"},
{tokAnd, "&"},
{tokIdent, "c"},
{tokEOF, ""},
},
},
{
name: "empty ref",
input: "+",
errMsg: "expected set name after '+'",
},
{
name: "invalid character",
input: "a @ b",
errMsg: "unexpected character",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens, err := tokenize(tt.input)
if tt.errMsg != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expect, tokens)
}
})
}
}
func TestDSL_ParseAndExpand(t *testing.T) {
tests := []struct {
name string
dsl string
refs map[string][][]string
expect [][]string
errMsg string
}{
{
name: "single model",
dsl: "L",
expect: [][]string{{"L"}},
},
{
name: "two models with AND",
dsl: "a & b",
expect: [][]string{{"a", "b"}},
},
{
name: "two models with OR",
dsl: "a | b",
expect: [][]string{{"a"}, {"b"}},
},
{
name: "three models with OR",
dsl: "a | b | c",
expect: [][]string{{"a"}, {"b"}, {"c"}},
},
{
name: "cartesian product (a|b) & (c|d)",
dsl: "(a | b) & (c | d)",
expect: [][]string{
{"a", "c"},
{"a", "d"},
{"b", "c"},
{"b", "d"},
},
},
{
name: "three-way AND",
dsl: "a & b & c",
expect: [][]string{
{"a", "b", "c"},
},
},
{
name: "(g | q | m) & v",
dsl: "(g | q | m) & v",
expect: [][]string{
{"g", "v"},
{"q", "v"},
{"m", "v"},
},
},
{
name: "(g | q) & v & e",
dsl: "(g | q) & v & e",
expect: [][]string{
{"e", "g", "v"},
{"e", "q", "v"},
},
},
{
name: "precedence: a | b & c means a | (b & c)",
dsl: "a | b & c",
expect: [][]string{
{"a"},
{"b", "c"},
},
},
{
name: "+ref inlining",
dsl: "+llms & v",
refs: map[string][][]string{
"llms": {{"g"}, {"q"}, {"m"}},
},
expect: [][]string{
{"g", "v"},
{"q", "v"},
{"m", "v"},
},
},
{
name: "+ref chained",
dsl: "+with_tts & e",
refs: map[string][][]string{
"with_tts": {{"g", "v"}, {"q", "v"}, {"m", "v"}},
},
expect: [][]string{
{"e", "g", "v"},
{"e", "q", "v"},
{"e", "m", "v"},
},
},
{
name: "dedup within combination",
dsl: "a & a",
expect: [][]string{
{"a"},
},
},
{
name: "empty expression",
dsl: "",
errMsg: "empty DSL expression",
},
{
name: "unmatched open paren",
dsl: "(a | b",
errMsg: "missing closing parenthesis",
},
{
name: "unmatched close paren",
dsl: "a | b)",
errMsg: "unexpected token",
},
{
name: "unknown ref",
dsl: "+unknown",
errMsg: "unknown set reference +unknown",
},
{
name: "empty parens",
dsl: "()",
errMsg: "unexpected token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
refs := tt.refs
if refs == nil {
refs = map[string][][]string{}
}
result, err := ParseAndExpandDSL(tt.dsl, refs)
if tt.errMsg != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expect, result)
}
})
}
}
func TestDSL_ExpansionCap(t *testing.T) {
// Build an expression that would exceed 1000 combinations:
// (a1|a2|...|a32) & (b1|b2|...|b32) = 1024 combos
var aItems, bItems []string
for i := 0; i < 32; i++ {
aItems = append(aItems, fmt.Sprintf("a%d", i))
bItems = append(bItems, fmt.Sprintf("b%d", i))
}
dsl := fmt.Sprintf("(%s) & (%s)",
join(aItems, " | "),
join(bItems, " | "),
)
_, err := ParseAndExpandDSL(dsl, map[string][][]string{})
require.Error(t, err)
assert.Contains(t, err.Error(), "exceeded")
}
func TestDSL_ExtractRefs(t *testing.T) {
refs, err := extractRefs("+llms & v & +other")
require.NoError(t, err)
assert.Equal(t, []string{"llms", "other"}, refs)
refs, err = extractRefs("a & b")
require.NoError(t, err)
assert.Empty(t, refs)
}
func join(items []string, sep string) string {
result := ""
for i, item := range items {
if i > 0 {
result += sep
}
result += item
}
return result
}
+305
View File
@@ -0,0 +1,305 @@
package config
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func makeModels(names ...string) map[string]ModelConfig {
m := make(map[string]ModelConfig)
for _, name := range names {
m[name] = ModelConfig{Cmd: "echo " + name}
}
return m
}
func TestValidateMatrix_Basic(t *testing.T) {
models := makeModels("gemma", "qwen", "mistral", "voxtral", "llama70B")
matrix := MatrixConfig{
Var: map[string]string{
"g": "gemma",
"q": "qwen",
"m": "mistral",
"v": "voxtral",
"L": "llama70B",
},
EvictCosts: map[string]int{
"L": 30,
"v": 50,
},
Sets: OrderedSets{
{Name: "standard", DSL: "(g | q | m) & v"},
{Name: "full", DSL: "L"},
},
}
expanded, err := ValidateMatrix(matrix, models)
require.NoError(t, err)
// standard expands to [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
// full expands to [llama70B]
assert.Len(t, expanded, 4)
assert.Equal(t, "standard", expanded[0].SetName)
assert.Equal(t, []string{"gemma", "voxtral"}, expanded[0].Models)
assert.Equal(t, "standard", expanded[1].SetName)
assert.Equal(t, []string{"qwen", "voxtral"}, expanded[1].Models)
assert.Equal(t, "standard", expanded[2].SetName)
assert.Equal(t, []string{"mistral", "voxtral"}, expanded[2].Models)
assert.Equal(t, "full", expanded[3].SetName)
assert.Equal(t, []string{"llama70B"}, expanded[3].Models)
}
func TestValidateMatrix_WithRef(t *testing.T) {
models := makeModels("gemma", "qwen", "mistral", "voxtral", "reranker")
matrix := MatrixConfig{
Var: map[string]string{
"g": "gemma",
"q": "qwen",
"m": "mistral",
"v": "voxtral",
"e": "reranker",
},
Sets: OrderedSets{
{Name: "llms", DSL: "g | q | m"},
{Name: "with_tts", DSL: "+llms & v"},
{Name: "mega", DSL: "+with_tts & e"},
},
}
expanded, err := ValidateMatrix(matrix, models)
require.NoError(t, err)
// llms: [gemma], [qwen], [mistral]
// with_tts: [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
// mega: [gemma,reranker,voxtral], [qwen,reranker,voxtral], [mistral,reranker,voxtral]
assert.Len(t, expanded, 9)
// Check mega entries
megaEntries := filterBySetName(expanded, "mega")
assert.Len(t, megaEntries, 3)
assert.Equal(t, []string{"gemma", "reranker", "voxtral"}, megaEntries[0].Models)
}
func TestValidateMatrix_MapIDRequired(t *testing.T) {
// DSL cannot use real model names directly — must use var IDs
models := makeModels("gemma", "voxtral")
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "combo", DSL: "g & voxtral"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown var ID")
}
func TestValidateMatrix_InvalidAliasKey(t *testing.T) {
models := makeModels("gemma")
tests := []struct {
name string
alias string
errMsg string
}{
{"too long", "abcdefghi", "alphanumeric and 1-8 characters"},
{"has underscore", "a_b", "alphanumeric and 1-8 characters"},
{"has hyphen", "a-b", "alphanumeric and 1-8 characters"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matrix := MatrixConfig{
Var: map[string]string{tt.alias: "gemma"},
Sets: OrderedSets{{Name: "s", DSL: tt.alias}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
})
}
}
func TestValidateMatrix_AliasReferencesUnknownModel(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Var: map[string]string{"x": "nonexistent"},
Sets: OrderedSets{{Name: "s", DSL: "x"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown model")
}
func TestValidateMatrix_EvictCostInvalid(t *testing.T) {
models := makeModels("gemma")
t.Run("zero cost", func(t *testing.T) {
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"g": 0},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "positive integer")
})
t.Run("negative cost", func(t *testing.T) {
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"g": -1},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "positive integer")
})
t.Run("unknown var ID in evict_costs", func(t *testing.T) {
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
EvictCosts: map[string]int{"unknown": 5},
Sets: OrderedSets{{Name: "s", DSL: "g"}},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown var ID")
})
}
func TestValidateMatrix_CycleDetection(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "a", DSL: "+b"},
{Name: "b", DSL: "+a"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "circular reference")
}
func TestValidateMatrix_UndefinedRefTarget(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "a", DSL: "+nonexistent"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "references undefined set")
}
func TestValidateMatrix_NoSets(t *testing.T) {
_, err := ValidateMatrix(MatrixConfig{}, makeModels("gemma"))
require.Error(t, err)
assert.Contains(t, err.Error(), "at least one set")
}
func TestValidateMatrix_UnknownMapIDInDSL(t *testing.T) {
models := makeModels("gemma")
matrix := MatrixConfig{
Var: map[string]string{"g": "gemma"},
Sets: OrderedSets{
{Name: "s", DSL: "g & nonexistent"},
},
}
_, err := ValidateMatrix(matrix, models)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown var ID")
}
func TestValidateMatrix_ResolvedEvictCosts(t *testing.T) {
mc := &MatrixConfig{
Var: map[string]string{
"g": "gemma",
"L": "llama70B",
},
EvictCosts: map[string]int{
"L": 30,
"g": 5,
},
}
costs := mc.ResolvedEvictCosts()
assert.Equal(t, 30, costs["llama70B"])
assert.Equal(t, 5, costs["gemma"])
}
func TestValidateMatrix_ConfigXOR(t *testing.T) {
// groups and matrix both defined
yaml := `
models:
model1:
cmd: echo model1
proxy: http://localhost:8080
groups:
group1:
members:
- model1
matrix:
sets:
s: "model1"
`
_, err := LoadConfigFromReader(strings.NewReader(yaml))
require.Error(t, err)
assert.Contains(t, err.Error(), "cannot use both")
}
func TestValidateMatrix_ConfigMatrixOnly(t *testing.T) {
yaml := `
models:
gemma:
cmd: echo gemma
proxy: http://localhost:8080
qwen:
cmd: echo qwen
proxy: http://localhost:8081
matrix:
vars:
g: gemma
q: qwen
sets:
combo: "g | q"
`
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
require.NoError(t, err)
assert.NotNil(t, cfg.Matrix)
assert.Len(t, cfg.ExpandedSets, 2)
// Groups should be empty when matrix is used
assert.Empty(t, cfg.Groups)
}
func filterBySetName(sets []ExpandedSet, name string) []ExpandedSet {
var result []ExpandedSet
for _, s := range sets {
if s.SetName == name {
result = append(result, s)
}
}
return result
}
+136
View File
@@ -0,0 +1,136 @@
package config
import (
"errors"
"runtime"
)
const (
MODEL_CONFIG_DEFAULT_TTL = -1
)
// TimeoutsConfig holds timeout settings for proxy connections
// 0 = no timeout
type TimeoutsConfig struct {
Connect int `yaml:"connect"`
KeepAlive int `yaml:"keepalive"`
ResponseHeader int `yaml:"responseHeader"`
TLSHandshake int `yaml:"tlsHandshake"`
ExpectContinue int `yaml:"expectContinue"`
IdleConn int `yaml:"idleConn"`
}
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
// Macros: see #264
// Model level macros take precedence over the global macros
Macros MacroList `yaml:"macros"`
// Metadata: see #264
// Arbitrary metadata that can be exposed through the API
Metadata map[string]any `yaml:"metadata"`
// override global setting
SendLoadingState *bool `yaml:"sendLoadingState"`
// Timeout settings for proxy connections
Timeouts TimeoutsConfig `yaml:"timeouts"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
// matches http.DefaultTransport
Timeouts: TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelConfig(defaults)
return nil
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters embeds Filters and adds legacy support for strip_params field
// See issue #174
type ModelFilters struct {
Filters `yaml:",inline"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{}
if err := unmarshal(&defaults); err != nil {
return err
}
// Try to unmarshal with the old field name for backwards compatibility
if defaults.StripParams == "" {
var legacy struct {
StripParams string `yaml:"strip_params"`
}
if legacyErr := unmarshal(&legacy); legacyErr != nil {
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
}
defaults.StripParams = legacy.StripParams
}
*m = ModelFilters(defaults)
return nil
}
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
// Returns ([]string, error) to match existing API
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
return f.Filters.SanitizedStripParams(), nil
}
+172
View File
@@ -0,0 +1,172 @@
package config
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
--arg1 value1 \
--arg2 value2`,
}
args, err := config.SanitizedCommand()
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestConfig_ModelFilters(t *testing.T) {
content := `
macros:
default_strip: "temperature, top_p"
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
# macros inserted and list is cleaned of duplicates and empty strings
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
# check for strip_params (legacy field name) compatibility
legacy:
cmd: path/to/cmd --port ${PORT}
filters:
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
for modelId, modelConfig := range config.Models {
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
sanitized, err := modelConfig.Filters.SanitizedStripParams()
if assert.NoError(t, err) {
// model has been removed
// empty strings have been removed
// duplicates have been removed
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
}
})
}
}
func TestConfig_ModelSendLoadingState(t *testing.T) {
content := `
sendLoadingState: true
models:
model1:
cmd: path/to/cmd --port ${PORT}
sendLoadingState: false
model2:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.True(t, config.SendLoadingState)
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
assert.False(t, *config.Models["model1"].SendLoadingState)
}
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
assert.True(t, *config.Models["model2"].SendLoadingState)
}
}
func TestConfig_SetParamsByIDAutoAlias(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
setParamsByID:
"${MODEL_ID}:high":
reasoning_effort: high
"${MODEL_ID}:low":
reasoning_effort: low
`
cfg, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// Keys (other than the model's own ID) should be registered as aliases
realName, found := cfg.RealModelName("model1:high")
assert.True(t, found, "model1:high should be an auto-registered alias")
assert.Equal(t, "model1", realName)
realName, found = cfg.RealModelName("model1:low")
assert.True(t, found, "model1:low should be an auto-registered alias")
assert.Equal(t, "model1", realName)
// Auto-aliases should also appear in modelConfig.Aliases
aliases := cfg.Models["model1"].Aliases
assert.Contains(t, aliases, "model1:high")
assert.Contains(t, aliases, "model1:low")
}
func TestConfig_SetParamsByIDAutoAliasConflictWithModelID(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
setParamsByID:
model2:
reasoning_effort: high
model2:
cmd: path/to/cmd --port ${PORT}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.ErrorContains(t, err, "conflicts with an existing model ID")
}
func TestConfig_SetParamsByIDAutoAliasConflictWithOtherModel(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
setParamsByID:
"shared-alias":
reasoning_effort: high
model2:
cmd: path/to/cmd --port ${PORT}
filters:
setParamsByID:
"shared-alias":
reasoning_effort: low
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.ErrorContains(t, err, "duplicate alias")
}
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
stripParams: "top_k"
setParams:
temperature: 0.7
top_p: 0.9
stop:
- "<|end|>"
- "<|stop|>"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
modelConfig := config.Models["model1"]
// Check stripParams
stripParams, err := modelConfig.Filters.SanitizedStripParams()
assert.NoError(t, err)
assert.Equal(t, []string{"top_k"}, stripParams)
// Check setParams
setParams, keys := modelConfig.Filters.SanitizedSetParams()
assert.NotNil(t, setParams)
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
assert.Equal(t, 0.7, setParams["temperature"])
assert.Equal(t, 0.9, setParams["top_p"])
}
+63
View File
@@ -0,0 +1,63 @@
package config
import (
"fmt"
"net/url"
)
type PeerDictionaryConfig map[string]PeerConfig
type PeerConfig struct {
Proxy string `yaml:"proxy"`
ProxyURL *url.URL `yaml:"-"`
ApiKey string `yaml:"apiKey"`
Models []string `yaml:"models"`
Filters Filters `yaml:"filters"`
// Timeout settings for proxy connections
Timeouts TimeoutsConfig `yaml:"timeouts"`
}
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawPeerConfig PeerConfig
defaults := rawPeerConfig{
Proxy: "",
ApiKey: "",
Models: []string{},
Filters: Filters{},
// mostly matches http.DefaultTransport but with a 60s ResponseHeader timeout
// to match the pre PR #619 functionality
Timeouts: TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}
if err := unmarshal(&defaults); err != nil {
return err
}
// Validate proxy is not empty
if defaults.Proxy == "" {
return fmt.Errorf("proxy is required")
}
// Validate proxy is a valid URL and store the parsed value
parsedURL, err := url.Parse(defaults.Proxy)
if err != nil {
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
}
defaults.ProxyURL = parsedURL
// Validate models is not empty
if len(defaults.Models) == 0 {
return fmt.Errorf("peer models can not be empty")
}
*c = PeerConfig(defaults)
return nil
}
+209
View File
@@ -0,0 +1,209 @@
package config
import (
"testing"
"gopkg.in/yaml.v3"
)
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
tests := []struct {
name string
yaml string
wantErr string
}{
{
name: "valid config",
yaml: `
proxy: http://192.168.1.23
models:
- model_a
- model_b
`,
wantErr: "",
},
{
name: "valid config with apiKey",
yaml: `
proxy: https://openrouter.ai/api
apiKey: sk-test-key
models:
- meta-llama/llama-3.1-8b-instruct
`,
wantErr: "",
},
{
name: "missing proxy",
yaml: `
models:
- model_a
`,
wantErr: "proxy is required",
},
{
name: "empty proxy",
yaml: `
proxy: ""
models:
- model_a
`,
wantErr: "proxy is required",
},
{
name: "invalid proxy URL",
yaml: `
proxy: "://invalid"
models:
- model_a
`,
wantErr: "invalid peer proxy URL",
},
{
name: "missing models",
yaml: `
proxy: http://localhost:8080
`,
wantErr: "peer models can not be empty",
},
{
name: "empty models",
yaml: `
proxy: http://localhost:8080
models: []
`,
wantErr: "peer models can not be empty",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var config PeerConfig
err := yaml.Unmarshal([]byte(tt.yaml), &config)
if tt.wantErr == "" {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
} else {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.wantErr)
} else if !contains(err.Error(), tt.wantErr) {
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
}
}
})
}
}
func TestPeerConfig_ProxyURL(t *testing.T) {
yamlData := `
proxy: http://192.168.1.23:8080/api
apiKey: sk-test
models:
- model_a
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if config.ProxyURL == nil {
t.Fatal("ProxyURL should not be nil")
}
if config.ProxyURL.Host != "192.168.1.23:8080" {
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
}
if config.ProxyURL.Scheme != "http" {
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
}
if config.ProxyURL.Path != "/api" {
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && searchSubstring(s, substr)
}
func searchSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func TestPeerConfig_WithFilters(t *testing.T) {
yamlData := `
proxy: https://openrouter.ai/api
apiKey: sk-test
models:
- model_a
filters:
setParams:
temperature: 0.7
provider:
data_collection: deny
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if config.Filters.SetParams == nil {
t.Fatal("Filters.SetParams should not be nil")
}
if config.Filters.SetParams["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
}
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
if !ok {
t.Fatal("provider should be a map")
}
if provider["data_collection"] != "deny" {
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
}
}
func TestPeerConfig_WithBothFilters(t *testing.T) {
yamlData := `
proxy: https://openrouter.ai/api
apiKey: sk-test
models:
- model_a
filters:
stripParams: "temperature, top_p"
setParams:
max_tokens: 1000
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check stripParams
stripParams := config.Filters.SanitizedStripParams()
if len(stripParams) != 2 {
t.Errorf("expected 2 strip params, got %d", len(stripParams))
}
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
t.Errorf("unexpected strip params: %v", stripParams)
}
// Check setParams
if config.Filters.SetParams == nil {
t.Fatal("Filters.SetParams should not be nil")
}
if config.Filters.SetParams["max_tokens"] != 1000 {
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
}
}
-483
View File
@@ -1,483 +0,0 @@
package proxy
import (
"slices"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_GroupMemberIsUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
model3:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
healthCheckTimeout: 15
groups:
group1:
swap: true
exclusive: false
members: ["model2"]
group2:
swap: true
exclusive: false
members: ["model2"]
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// a Contains as order of the map is not guaranteed
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
}
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
aliases:
- m1
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
checkEndpoint: "/"
aliases:
- m1
- m2
`
// Load the config and verify
_, err := LoadConfigFromReader(strings.NewReader(content))
// this is a contains because it could be `model1` or `model2` depending on the order
// go decided on the order of the map
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
}
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
--arg1 value1 \
--arg2 value2`,
}
args, err := config.SanitizedCommand()
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestConfig_FindConfig(t *testing.T) {
// TODO?
// make make this shared between the different tests
config := &Config{
Models: map[string]ModelConfig{
"model1": {
Cmd: "python model1.py",
Proxy: "http://localhost:8080",
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
},
"model2": {
Cmd: "python model2.py",
Proxy: "http://localhost:8081",
Aliases: []string{"m2", "model-two"},
Env: []string{"VAR3=value3", "VAR4=value4"},
CheckEndpoint: "/status",
},
},
HealthCheckTimeout: 10,
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
},
}
// Test finding a model by its name
modelConfig, modelId, found := config.FindConfig("model1")
assert.True(t, found)
assert.Equal(t, "model1", modelId)
assert.Equal(t, config.Models["model1"], modelConfig)
// Test finding a model by its alias
modelConfig, modelId, found = config.FindConfig("m1")
assert.True(t, found)
assert.Equal(t, "model1", modelId)
assert.Equal(t, config.Models["model1"], modelConfig)
// Test finding a model that does not exist
modelConfig, modelId, found = config.FindConfig("model3")
assert.False(t, found)
assert.Equal(t, "", modelId)
assert.Equal(t, ModelConfig{}, modelConfig)
}
func TestConfig_AutomaticPortAssignments(t *testing.T) {
t.Run("Default Port Ranges", func(t *testing.T) {
content := ``
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
})
t.Run("User specific port ranges", func(t *testing.T) {
content := `startPort: 1000`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 1000, config.StartPort)
})
t.Run("Invalid start port", func(t *testing.T) {
content := `startPort: abcd`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("start port must be greater than 1", func(t *testing.T) {
content := `startPort: -99`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NotNil(t, err)
})
t.Run("Automatic port assignments", func(t *testing.T) {
content := `
startPort: 5800
models:
model1:
cmd: svr --port ${PORT}
model2:
cmd: svr --port ${PORT}
proxy: "http://172.11.22.33:${PORT}"
model3:
cmd: svr --port 1999
proxy: "http://1.2.3.4:1999"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
if !assert.NoError(t, err) {
t.Fatalf("Failed to load config: %v", err)
}
assert.Equal(t, 5800, config.StartPort)
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
})
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
content := `
models:
model1:
cmd: svr --port 111
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Equal(t, "model model1: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", err.Error())
})
}
func TestConfig_MacroReplacement(t *testing.T) {
content := `
startPort: 9990
macros:
svr-path: "path/to/server"
argOne: "--arg1"
argTwo: "--arg2"
autoPort: "--port ${PORT}"
models:
model1:
cmd: |
${svr-path} ${argTwo}
# the automatic ${PORT} is replaced
${autoPort}
${argOne}
--arg3 three
cmdStop: |
/path/to/stop.sh --port ${PORT} ${argTwo}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
assert.NoError(t, err)
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
}
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
tests := []struct {
name string
field string
content string
}{
{
name: "unknown macro in cmd",
field: "cmd",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: |
${svr-path} --port ${PORT}
${unknownMacro}
`,
},
{
name: "unknown macro in cmdStop",
field: "cmdStop",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: "${svr-path} --port ${PORT}"
cmdStop: "kill ${unknownMacro}"
`,
},
{
name: "unknown macro in proxy",
field: "proxy",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: "${svr-path} --port ${PORT}"
proxy: "http://localhost:${unknownMacro}"
`,
},
{
name: "unknown macro in checkEndpoint",
field: "checkEndpoint",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: "${svr-path} --port ${PORT}"
checkEndpoint: "http://localhost:${unknownMacro}/health"
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
//t.Log(err)
})
}
}
func TestConfig_ModelFilters(t *testing.T) {
content := `
macros:
default_strip: "temperature, top_p"
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
strip_params: "model, top_k, ${default_strip}, , ,"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
modelConfig, ok := config.Models["model1"]
if !assert.True(t, ok) {
t.FailNow()
}
// make sure `model` and enmpty strings are not in the list
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
sanitized, err := modelConfig.Filters.SanitizedStripParams()
if assert.NoError(t, err) {
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
}
}
func TestStripComments(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no comments",
input: "echo hello\necho world",
expected: "echo hello\necho world",
},
{
name: "single comment line",
input: "# this is a comment\necho hello",
expected: "echo hello",
},
{
name: "multiple comment lines",
input: "# comment 1\necho hello\n# comment 2\necho world",
expected: "echo hello\necho world",
},
{
name: "comment with spaces",
input: " # indented comment\necho hello",
expected: "echo hello",
},
{
name: "empty lines preserved",
input: "echo hello\n\necho world",
expected: "echo hello\n\necho world",
},
{
name: "only comments",
input: "# comment 1\n# comment 2",
expected: "",
},
{
name: "empty string",
input: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := StripComments(tt.input)
if result != tt.expected {
t.Errorf("StripComments() = %q, expected %q", result, tt.expected)
}
})
}
}
func TestConfig_MacroInCommentStrippedBeforeExpansion(t *testing.T) {
// Test case that reproduces the original bug where a macro in a comment
// would get expanded and cause the comment text to be included in the command
content := `
startPort: 9990
macros:
"latest-llama": >
/user/llama.cpp/build/bin/llama-server
--port ${PORT}
models:
"test-model":
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model /path/to/model.gguf
-ngl 99
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// Get the sanitized command
sanitizedCmd, err := SanitizeCommand(config.Models["test-model"].Cmd)
assert.NoError(t, err)
// Join the command for easier inspection
cmdStr := strings.Join(sanitizedCmd, " ")
// Verify that comment text is NOT present in the final command as separate arguments
commentWords := []string{"is", "macro", "that", "defined", "above"}
for _, word := range commentWords {
found := slices.Contains(sanitizedCmd, word)
assert.False(t, found, "Comment text '%s' should not be present as a separate argument in final command", word)
}
// Verify that the actual command components ARE present
expectedParts := []string{
"/user/llama.cpp/build/bin/llama-server",
"--port",
"9990",
"--model",
"/path/to/model.gguf",
"-ngl",
"99",
}
for _, part := range expectedParts {
assert.Contains(t, cmdStr, part, "Expected command part '%s' not found in final command", part)
}
// Verify the server path appears exactly once (not duplicated due to macro expansion)
serverPath := "/user/llama.cpp/build/bin/llama-server"
count := strings.Count(cmdStr, serverPath)
assert.Equal(t, 1, count, "Expected exactly 1 occurrence of server path, found %d", count)
// Verify the expected final command structure
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
}
func TestConfig_MacroModelId(t *testing.T) {
content := `
startPort: 9000
macros:
"docker-llama": docker run --name ${MODEL_ID} -p ${PORT}:8080 docker_img
"docker-stop": docker stop ${MODEL_ID}
models:
model1:
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
model2:
cmd: ${docker-llama}
cmdStop: ${docker-stop}
author/model:F16:
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
cmdStop: stop
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
assert.Equal(t, "docker stop ${MODEL_ID}", config.Macros["docker-stop"])
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "docker run --name model2 -p 9002:8080 docker_img", strings.Join(sanitizedCmd2, " "))
sanitizedCmdStop, err := SanitizeCommand(config.Models["model2"].CmdStop)
assert.NoError(t, err)
assert.Equal(t, "docker stop model2", strings.Join(sanitizedCmdStop, " "))
sanitizedCmd3, err := SanitizeCommand(config.Models["author/model:F16"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
}
+9
View File
@@ -8,6 +8,7 @@ const ConfigFileChangedEventID = 0x03
const LogDataEventID = 0x04
const TokenMetricsEventID = 0x05
const ModelPreloadedEventID = 0x06
const InFlightRequestsEventID = 0x07
type ProcessStateChangeEvent struct {
ProcessName string
@@ -58,3 +59,11 @@ type ModelPreloadedEvent struct {
func (e ModelPreloadedEvent) Type() uint32 {
return ModelPreloadedEventID
}
type InFlightRequestsEvent struct {
Total int
}
func (e InFlightRequestsEvent) Type() uint32 {
return InFlightRequestsEventID
}
+211 -4
View File
@@ -1,14 +1,22 @@
package proxy
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"gopkg.in/yaml.v3"
)
@@ -65,21 +73,220 @@ func getTestPort() int {
return port
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
t.Helper()
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
require.NoError(t, err)
return cfg
}
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
}
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
// Convert path to forward slashes for cross-platform compatibility
// Windows handles forward slashes in paths correctly
cmdPath := filepath.ToSlash(simpleResponderPath)
// Create a YAML string with just the values we want to set
yamlStr := fmt.Sprintf(`
cmd: '%s --port %d --silent --respond %s'
proxy: "http://127.0.0.1:%d"
`, simpleResponderPath, port, expectedMessage, port)
`, cmdPath, port, expectedMessage, port)
var cfg ModelConfig
var cfg config.ModelConfig
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
}
return cfg
}
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
// model IDs to their respond strings; if a model ID is not in the map, the model
// ID itself is used.
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
for _, pg := range pm.processGroups {
for modelID, process := range pg.processes {
respond := modelID
if r, ok := modelResponses[modelID]; ok {
respond = r
}
process.testHandler = newTestHandler(respond)
}
}
}
// newTestHandler returns an http.Handler that mimics simple-responder's API.
// It supports the endpoints that routing tests depend on, without launching
// any subprocess or binding any port.
func newTestHandler(respond string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
isStreaming := r.URL.Query().Get("stream") == "true"
if wait := r.URL.Query().Get("wait"); wait != "" {
if d, err := time.ParseDuration(wait); err == nil {
time.Sleep(d)
}
}
if isStreaming {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
flusher := w.(http.Flusher)
for i := 0; i < 10; i++ {
data, _ := json.Marshal(map[string]any{
"created": time.Now().Unix(),
"choices": []map[string]any{
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
},
})
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
flusher.Flush()
}
finalData, _ := json.Marshal(map[string]any{
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
"timings": map[string]any{
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
"predicted_ms": 17, "predicted_per_second": 10,
},
})
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
flusher.Flush()
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
flusher.Flush()
} else {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"h_content_length": r.Header.Get("Content-Length"),
"request_body": string(bodyBytes),
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
"timings": map[string]any{
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
"predicted_ms": 17, "predicted_per_second": 10,
},
})
}
})
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
if modelName != respond {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
return
}
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
})
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
})
})
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
})
})
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
model := r.FormValue("model")
if model == "" {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
return
}
file, _, err := r.FormFile("file")
if err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
fileBytes, _ := io.ReadAll(file)
file.Close()
json.NewEncoder(w).Encode(map[string]any{
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
"model": model,
"h_content_type": r.Header.Get("Content-Type"),
"h_content_length": r.Header.Get("Content-Length"),
})
})
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
model := r.URL.Query().Get("model")
json.NewEncoder(w).Encode(map[string]any{
"voices": []string{"voice1"}, "model": model,
})
})
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprint(w, respond)
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
})
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
json.NewEncoder(w).Encode(map[string]any{
"model": modelName, "images": []string{},
})
})
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
json.NewEncoder(w).Encode(map[string]any{
"model": modelName, "images": []string{},
})
})
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"loras": []string{},
})
})
return mux
}
+121 -21
View File
@@ -1,16 +1,95 @@
package proxy
import (
"container/ring"
"context"
"fmt"
"io"
"os"
"sync"
"time"
"github.com/mostlygeek/llama-swap/event"
)
// circularBuffer is a fixed-size circular byte buffer that overwrites
// oldest data when full. It provides O(1) writes and O(n) reads.
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
}
func newCircularBuffer(capacity int) *circularBuffer {
return &circularBuffer{
data: make([]byte, capacity),
head: 0,
size: 0,
}
}
// Write appends bytes to the buffer, overwriting oldest data when full.
// Data is copied into the internal buffer (not stored by reference).
func (cb *circularBuffer) Write(p []byte) {
if len(p) == 0 {
return
}
cap := len(cb.data)
// If input is larger than capacity, only keep the last cap bytes
if len(p) >= cap {
copy(cb.data, p[len(p)-cap:])
cb.head = 0
cb.size = cap
return
}
// Calculate how much space is available from head to end of buffer
firstPart := cap - cb.head
if firstPart >= len(p) {
// All data fits without wrapping
copy(cb.data[cb.head:], p)
cb.head = (cb.head + len(p)) % cap
} else {
// Data wraps around
copy(cb.data[cb.head:], p[:firstPart])
copy(cb.data[:len(p)-firstPart], p[firstPart:])
cb.head = len(p) - firstPart
}
// Update size
cb.size += len(p)
if cb.size > cap {
cb.size = cap
}
}
// GetHistory returns all buffered data in correct order (oldest to newest).
// Returns a new slice (copy), not a view into internal buffer.
func (cb *circularBuffer) GetHistory() []byte {
if cb.size == 0 {
return nil
}
result := make([]byte, cb.size)
cap := len(cb.data)
// Calculate start position (oldest data)
start := (cb.head - cb.size + cap) % cap
if start+cb.size <= cap {
// Data is contiguous, single copy
copy(result, cb.data[start:start+cb.size])
} else {
// Data wraps around, two copies
firstPart := cap - start
copy(result[:firstPart], cb.data[start:])
copy(result[firstPart:], cb.data[:cb.size-firstPart])
}
return result
}
type LogLevel int
const (
@@ -18,12 +97,14 @@ const (
LevelInfo
LevelWarn
LevelError
LogBufferSize = 100 * 1024
)
type LogMonitor struct {
eventbus *event.Dispatcher
mu sync.RWMutex
buffer *ring.Ring
buffer *circularBuffer
bufferMu sync.RWMutex
// typically this can be os.Stdout
@@ -32,6 +113,9 @@ type LogMonitor struct {
// logging levels
level LogLevel
prefix string
// timestamps
timeFormat string
}
func NewLogMonitor() *LogMonitor {
@@ -40,11 +124,12 @@ func NewLogMonitor() *LogMonitor {
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{
eventbus: event.NewDispatcherConfig(1000),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout,
level: LevelInfo,
prefix: "",
eventbus: event.NewDispatcherConfig(1000),
buffer: nil, // lazy initialized on first Write
stdout: stdout,
level: LevelInfo,
prefix: "",
timeFormat: "",
}
}
@@ -59,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
}
w.bufferMu.Lock()
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.buffer.Value = bufferCopy
w.buffer = w.buffer.Next()
if w.buffer == nil {
w.buffer = newCircularBuffer(LogBufferSize)
}
w.buffer.Write(p)
w.bufferMu.Unlock()
// Make a copy for broadcast to preserve immutability
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.broadcast(bufferCopy)
return n, nil
}
@@ -72,16 +160,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
func (w *LogMonitor) GetHistory() []byte {
w.bufferMu.RLock()
defer w.bufferMu.RUnlock()
if w.buffer == nil {
return nil
}
return w.buffer.GetHistory()
}
var history []byte
w.buffer.Do(func(p any) {
if p != nil {
if content, ok := p.([]byte); ok {
history = append(history, content...)
}
}
})
return history
// Clear releases the buffer memory, making it eligible for GC.
// The buffer will be lazily re-allocated on the next Write.
func (w *LogMonitor) Clear() {
w.bufferMu.Lock()
w.buffer = nil
w.bufferMu.Unlock()
}
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
@@ -106,12 +196,22 @@ func (w *LogMonitor) SetLogLevel(level LogLevel) {
w.level = level
}
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
w.mu.Lock()
defer w.mu.Unlock()
w.timeFormat = timeFormat
}
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
prefix := ""
if w.prefix != "" {
prefix = fmt.Sprintf("[%s] ", w.prefix)
}
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
timestamp := ""
if w.timeFormat != "" {
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
}
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
}
func (w *LogMonitor) log(level LogLevel, msg string) {
+230
View File
@@ -3,8 +3,10 @@ package proxy
import (
"bytes"
"io"
"strings"
"sync"
"testing"
"time"
)
func TestLogMonitor(t *testing.T) {
@@ -84,3 +86,231 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
t.Errorf("Expected history to be %q, got %q", expected, history)
}
}
func TestWrite_LogTimeFormat(t *testing.T) {
// Create a new LogMonitor instance
lm := NewLogMonitorWriter(io.Discard)
// Enable timestamps
lm.timeFormat = time.RFC3339
// Write the message to the LogMonitor
lm.Info("Hello, World!")
// Get the history from the LogMonitor
history := lm.GetHistory()
timestamp := ""
fields := strings.Fields(string(history))
if len(fields) > 0 {
timestamp = fields[0]
} else {
t.Fatalf("Cannot extract string from history")
}
_, err := time.Parse(time.RFC3339, timestamp)
if err != nil {
t.Fatalf("Cannot find timestamp: %v", err)
}
}
func TestCircularBuffer_WrapAround(t *testing.T) {
// Create a small buffer to test wrap-around
cb := newCircularBuffer(10)
// Write "hello" (5 bytes)
cb.Write([]byte("hello"))
if got := string(cb.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Write "world" (5 bytes) - buffer now full
cb.Write([]byte("world"))
if got := string(cb.GetHistory()); got != "helloworld" {
t.Errorf("Expected 'helloworld', got %q", got)
}
// Write "12345" (5 bytes) - should overwrite "hello"
cb.Write([]byte("12345"))
if got := string(cb.GetHistory()); got != "world12345" {
t.Errorf("Expected 'world12345', got %q", got)
}
// Write data larger than buffer capacity
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
if got := string(cb.GetHistory()); got != "ghijklmnop" {
t.Errorf("Expected 'ghijklmnop', got %q", got)
}
}
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
// Test empty buffer
cb := newCircularBuffer(10)
if got := cb.GetHistory(); got != nil {
t.Errorf("Expected nil for empty buffer, got %q", got)
}
// Test exact capacity
cb.Write([]byte("1234567890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
// Test write exactly at capacity boundary
cb = newCircularBuffer(10)
cb.Write([]byte("12345"))
cb.Write([]byte("67890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
}
func TestLogMonitor_LazyInit(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Buffer should be nil before any writes
if lm.buffer != nil {
t.Error("Expected buffer to be nil before first write")
}
// GetHistory should return nil when buffer is nil
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history before first write, got %q", got)
}
// Write should lazily initialize the buffer
lm.Write([]byte("test"))
if lm.buffer == nil {
t.Error("Expected buffer to be initialized after write")
}
if got := string(lm.GetHistory()); got != "test" {
t.Errorf("Expected 'test', got %q", got)
}
}
func TestLogMonitor_Clear(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Write some data
lm.Write([]byte("hello"))
if got := string(lm.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Clear should release the buffer
lm.Clear()
if lm.buffer != nil {
t.Error("Expected buffer to be nil after Clear")
}
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history after Clear, got %q", got)
}
}
func TestLogMonitor_ClearAndReuse(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Write, clear, then write again
lm.Write([]byte("first"))
lm.Clear()
lm.Write([]byte("second"))
if got := string(lm.GetHistory()); got != "second" {
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
}
}
func BenchmarkLogMonitorWrite(b *testing.B) {
// Test data of varying sizes
smallMsg := []byte("small message\n")
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
b.Run("SmallWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(smallMsg)
}
})
b.Run("MediumWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
}
})
b.Run("LargeWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(largeMsg)
}
})
b.Run("WithSubscribers", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Add some subscribers
for i := 0; i < 5; i++ {
lm.OnLogData(func(data []byte) {})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
}
})
b.Run("GetHistory", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Pre-populate with data
for i := 0; i < 1000; i++ {
lm.Write(mediumMsg)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.GetHistory()
}
})
}
/*
Benchmark Results - MBP M1 Pro
Before (ring.Ring):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 43 ns | 40 B | 2 |
| MediumWrite (241B) | 76 ns | 264 B | 2 |
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
After (circularBuffer 10KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 67 ns | 240 B | 1 |
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
After (circularBuffer 100KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|-----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 66 ns | 240 B | 1 |
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
Summary:
- GetHistory: 139x faster (10KB), 18x faster (100KB)
- Allocations: reduced from 2 to 1 across all operations
- Small/medium writes: ~1.1-1.6x faster
*/
+329
View File
@@ -0,0 +1,329 @@
package proxy
import (
"fmt"
"net/http"
"slices"
"sort"
"sync"
"github.com/mostlygeek/llama-swap/proxy/config"
)
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
// It is safe for concurrent reads after construction.
type MatrixSolver struct {
expandedSets []config.ExpandedSet // all valid model combinations
evictCosts map[string]int // real model name -> eviction cost (default 1)
modelToSets map[string][]int // model name -> indices into expandedSets
}
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
modelToSets := make(map[string][]int)
for i, es := range expandedSets {
for _, model := range es.Models {
modelToSets[model] = append(modelToSets[model], i)
}
}
return &MatrixSolver{
expandedSets: expandedSets,
evictCosts: evictCosts,
modelToSets: modelToSets,
}
}
// SolveResult describes what the solver decided.
type SolveResult struct {
Evict []string // running models that must be stopped
TargetSet []string // the chosen set of models (for informational purposes)
SetName string // name of the chosen set
DSL string // original DSL expression for the chosen set
TotalCost int // total eviction cost
}
// Solve determines which models to evict when a model is requested.
//
// Algorithm:
// 1. If requestedModel is already running, no eviction needed.
// 2. Find all sets containing requestedModel.
// 3. If no sets found, the model runs alone; evict all running models.
// 4. For each candidate set, compute cost = sum of evict_costs for running
// models NOT in that set.
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
// 6. Return models to evict and the chosen set.
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
// If already running, nothing to do (but fill in set info for logging)
if slices.Contains(runningModels, requestedModel) {
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
return SolveResult{
TargetSet: runningModels,
SetName: setName,
DSL: dsl,
}, nil
}
candidateIndices := s.modelToSets[requestedModel]
// Model not in any set: runs alone, evict everything
if len(candidateIndices) == 0 {
evict := make([]string, len(runningModels))
copy(evict, runningModels)
return SolveResult{
Evict: evict,
TargetSet: []string{requestedModel},
}, nil
}
// Find the cheapest candidate set
bestCost := -1
bestIdx := -1
for _, idx := range candidateIndices {
setModels := s.expandedSets[idx].Models
cost := 0
for _, running := range runningModels {
if !slices.Contains(setModels, running) {
cost += s.evictCost(running)
}
}
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
bestCost = cost
bestIdx = idx
}
}
// Determine which running models to evict
chosen := s.expandedSets[bestIdx]
var evict []string
for _, running := range runningModels {
if !slices.Contains(chosen.Models, running) {
evict = append(evict, running)
}
}
return SolveResult{
Evict: evict,
TargetSet: chosen.Models,
SetName: chosen.SetName,
DSL: chosen.DSL,
TotalCost: bestCost,
}, nil
}
// findMatchingSet finds the expanded set that contains all running models.
// Returns the set name and DSL, or empty strings if no match.
func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
for _, idx := range s.modelToSets[requestedModel] {
set := s.expandedSets[idx]
allInSet := true
for _, m := range runningModels {
if !slices.Contains(set.Models, m) {
allInSet = false
break
}
}
if allInSet {
return set.SetName, set.DSL
}
}
return "", ""
}
func (s *MatrixSolver) evictCost(model string) int {
if cost, ok := s.evictCosts[model]; ok {
return cost
}
return 1
}
// Matrix manages processes using solver-based swap logic.
type Matrix struct {
sync.Mutex
solver *MatrixSolver
processes map[string]*Process // all processes keyed by real model name
config config.Config
proxyLogger *LogMonitor
upstreamLogger *LogMonitor
// inflight tracks ProxyRequest calls that have released m.Lock but may
// not yet have incremented Process.inFlightRequests. A concurrent
// request that needs to evict models waits for inflight to drain under
// m.Lock before stopping anything. Without this, a request that
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
// races with Stop()'s Wait() and can be killed mid-request.
inflight sync.WaitGroup
// testDelayFastPath is a test-only hook invoked in the no-eviction path
// after m.Lock is released but before the request is dispatched to
// Process.ProxyRequest. Tests use it to park a request at the exact
// race window to deterministically reproduce the race.
testDelayFastPath func()
}
// NewMatrix creates a Matrix from config. It creates a Process for every
// model defined in the config (any model can run alone even if not in a set).
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *LogMonitor) *Matrix {
processes := make(map[string]*Process)
for modelID, modelConfig := range cfg.Models {
processLogger := NewLogMonitorWriter(upstreamLogger)
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
processes[modelID] = process
}
evictCosts := cfg.Matrix.ResolvedEvictCosts()
return &Matrix{
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
processes: processes,
config: cfg,
proxyLogger: proxyLogger,
upstreamLogger: upstreamLogger,
}
}
// ProxyRequest handles the swap logic and proxies the request to the model.
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
process, ok := m.processes[modelID]
if !ok {
return fmt.Errorf("model %s not found in matrix", modelID)
}
m.Lock()
running := m.runningModels()
result, err := m.solver.Solve(modelID, running)
if err != nil {
m.Unlock()
return fmt.Errorf("matrix solver error: %w", err)
}
// Log solver decision
if len(result.Evict) > 0 {
m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
} else if len(running) == 0 {
m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
} else {
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
}
// Evict models that need to be stopped
if len(result.Evict) > 0 {
// Wait for any in-flight ProxyRequest calls to register on their
// Process before stopping anything. Without this, a request that
// released m.Lock but has not yet incremented
// Process.inFlightRequests races with Stop() and can be killed
// mid-request.
m.inflight.Wait()
var wg sync.WaitGroup
for _, evictModel := range result.Evict {
if p, exists := m.processes[evictModel]; exists {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
p.Stop()
}(p)
}
}
wg.Wait()
}
// Register this request in inflight before releasing m.Lock so a
// concurrent eviction will wait for it to complete.
m.inflight.Add(1)
defer m.inflight.Done()
isFastPath := len(result.Evict) == 0
m.Unlock()
if isFastPath && m.testDelayFastPath != nil {
m.testDelayFastPath()
}
// Proxy the request (Process handles on-demand start)
process.ProxyRequest(w, r)
return nil
}
// StopProcesses stops all running processes.
func (m *Matrix) StopProcesses(strategy StopStrategy) {
m.Lock()
defer m.Unlock()
var wg sync.WaitGroup
for _, process := range m.processes {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
switch strategy {
case StopImmediately:
p.StopImmediately()
default:
p.Stop()
}
}(process)
}
wg.Wait()
}
// StopProcess stops a single process by model ID.
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
process, ok := m.processes[modelID]
if !ok {
return fmt.Errorf("process not found for %s", modelID)
}
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
return nil
}
// Shutdown shuts down all processes.
func (m *Matrix) Shutdown() {
var wg sync.WaitGroup
for _, process := range m.processes {
wg.Add(1)
go func(p *Process) {
defer wg.Done()
p.Shutdown()
}(process)
}
wg.Wait()
}
// RunningModels returns model names currently in an active (non-stopped) state.
func (m *Matrix) RunningModels() []string {
m.Lock()
defer m.Unlock()
return m.runningModels()
}
// runningModels returns running model names (caller must hold lock).
func (m *Matrix) runningModels() []string {
var running []string
for id, process := range m.processes {
if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
running = append(running, id)
}
}
sort.Strings(running)
return running
}
// GetProcess returns the Process for a model.
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
p, ok := m.processes[modelID]
return p, ok
}
// HasModel returns true if the model is managed by this matrix.
func (m *Matrix) HasModel(modelID string) bool {
_, ok := m.processes[modelID]
return ok
}
+349
View File
@@ -0,0 +1,349 @@
package proxy
import (
"net/http"
"net/http/httptest"
"runtime"
"testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Helper to build expanded sets for solver tests
func makeExpandedSets(sets ...struct {
name string
models []string
}) []config.ExpandedSet {
var result []config.ExpandedSet
for _, s := range sets {
result = append(result, config.ExpandedSet{
SetName: s.name,
Models: s.models,
})
}
return result
}
func es(name string, models ...string) struct {
name string
models []string
} {
return struct {
name string
models []string
}{name, models}
}
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("a", []string{"a"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a"}, result.TargetSet)
assert.Equal(t, "s1", result.SetName)
}
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
// Model "c" not in any set
result, err := solver.Solve("c", []string{"a", "b"})
require.NoError(t, err)
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
assert.Equal(t, []string{"c"}, result.TargetSet)
}
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("c", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"c"}, result.TargetSet)
}
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
// Set: [a, b]. Request a when b and c are running.
solver := NewMatrixSolver(
makeExpandedSets(es("s1", "a", "b")),
nil,
)
result, err := solver.Solve("a", []string{"b", "c"})
require.NoError(t, err)
// c is not in the set, so it gets evicted. b is in the set, so it stays.
assert.Equal(t, []string{"c"}, result.Evict)
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
}
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
// Two sets containing model "a":
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "a", "v"),
es("s2", "a", "L"),
),
map[string]int{"v": 50, "L": 30},
)
// v is running. Switching to a:
// s1 cost: v is in s1, so 0
// s2 cost: v is NOT in s2, so 50
// => pick s1
result, err := solver.Solve("a", []string{"v"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
// L is running. Switching to a:
// s1 cost: L is NOT in s1, so 30
// s2 cost: L is in s2, so 0
// => pick s2
result, err = solver.Solve("a", []string{"L"})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
}
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
// Two sets with identical cost. Definition order should win.
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "a", "x"),
es("s2", "a", "y"),
),
nil,
)
// Nothing running, both sets cost 0. s1 is first.
result, err := solver.Solve("a", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
}
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
// Model "v" costs 50 to evict, "m" costs 1 (default).
// Sets: [g,v], [g,m]
// Running: v, m. Request g.
// s1=[g,v]: evict m (cost 1), keep v
// s2=[g,m]: evict v (cost 50), keep m
// => pick s1
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "g", "v"),
es("s2", "g", "m"),
),
map[string]int{"v": 50},
)
result, err := solver.Solve("g", []string{"v", "m"})
require.NoError(t, err)
assert.Equal(t, []string{"m"}, result.Evict)
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
}
func TestMatrixSolver_NothingRunning(t *testing.T) {
solver := NewMatrixSolver(
makeExpandedSets(
es("s1", "g", "v"),
es("s2", "q", "v"),
),
nil,
)
result, err := solver.Solve("g", []string{})
require.NoError(t, err)
assert.Empty(t, result.Evict)
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
}
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
// cannot stop a process while an in-flight ProxyRequest for that process is
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
// matrix-level inflight tracking, the eviction's Stop() races with the
// pending request and kills it mid-start.
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
cfg := config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
ExpandedSets: []config.ExpandedSet{
{SetName: "s1", Models: []string{"model1"}},
{SetName: "s2", Models: []string{"model2"}},
},
Matrix: &config.MatrixConfig{},
}
m := NewMatrix(cfg, testLogger, testLogger)
defer m.StopProcesses(StopImmediately)
// Bypass real subprocesses so the test is fast and deterministic.
m.processes["model1"].testHandler = newTestHandler("model1")
m.processes["model2"].testHandler = newTestHandler("model2")
// Prime: run a request through model1 so it reaches StateReady and
// subsequent requests take the no-eviction path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
// Install fast-path hook that signals arrival and waits for release.
// This parks R2 at the race window — after m.Lock is released but
// before Process.inFlightRequests.Add(1).
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
m.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
// R2: no-eviction request for model1. Will pause at the hook.
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, m.ProxyRequest("model1", w2, req))
}()
// Deterministically wait for R2 to reach the race window.
<-r2Reached
// R3: request for model2 which requires evicting model1. Must wait for
// R2 to finish before touching model1.
r3Done := make(chan struct{})
w3 := httptest.NewRecorder()
go func() {
defer close(r3Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, m.ProxyRequest("model2", w3, req))
}()
// Spin until R3 has acquired m.Lock and entered the eviction path. In
// the fixed code, R3 then blocks on m.inflight.Wait() while still
// holding the lock, so TryLock keeps failing.
for m.TryLock() {
m.Unlock()
runtime.Gosched()
}
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
// state. In the fixed code R3 is blocked and nothing changes; in the
// buggy code R3 will Stop() model1 and start model2 within microseconds.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if m.processes["model1"].CurrentState() != StateReady ||
m.processes["model2"].CurrentState() != StateStopped {
break
}
done := false
select {
case <-r3Done:
done = true
default:
}
if done {
break
}
runtime.Gosched()
}
// Invariant: R3 must be blocked while R2 is still in flight.
select {
case <-r3Done:
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
default:
}
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
"model1 must stay Ready while an in-flight request is pending")
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
"model2 must not be started until R2 finishes and model1 is evicted")
// Release R2 and let both requests finish.
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
assert.Equal(t, http.StatusOK, w3.Code)
assert.Contains(t, w3.Body.String(), "model2")
}
func TestMatrixSolver_FullScenario(t *testing.T) {
// Simulates the example config:
// standard: [g,v], [q,v], [m,v]
// with_rerank: [g,v,e], [q,v,e]
// creative: [g,sd], [q,sd]
// full: [L]
solver := NewMatrixSolver(
makeExpandedSets(
es("standard", "g", "v"),
es("standard", "q", "v"),
es("standard", "m", "v"),
es("with_rerank", "e", "g", "v"),
es("with_rerank", "e", "q", "v"),
es("creative", "g", "sd"),
es("creative", "q", "sd"),
es("full", "L"),
),
map[string]int{"v": 50, "L": 30, "whisper": 10},
)
// Running: g, v. Request q.
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
// => tie, pick first by definition order = standard[q,v]
result, err := solver.Solve("q", []string{"g", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"g"}, result.Evict)
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
// Running: g, v. Request L.
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
// Only one set contains L, so pick it.
result, err = solver.Solve("L", []string{"g", "v"})
require.NoError(t, err)
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
assert.Equal(t, []string{"L"}, result.TargetSet)
// Running: g, v. Request sd.
// creative[g,sd]: evict v (cost 50). Total: 50.
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
// => pick creative[g,sd]
result, err = solver.Solve("sd", []string{"g", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"v"}, result.Evict)
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
// Running: q, v, e. Request g.
// standard[g,v]: evict q (1) + e (1). Total: 2.
// with_rerank[g,v,e]: evict q (1). Total: 1.
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
// => pick with_rerank[g,v,e]
result, err = solver.Solve("g", []string{"e", "q", "v"})
require.NoError(t, err)
assert.Equal(t, []string{"q"}, result.Evict)
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
}
-184
View File
@@ -1,184 +0,0 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
// isStreaming bool
startTime time.Time
}
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
c.Abort()
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
c.Abort()
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
c.Abort()
return
}
writer := &MetricsResponseWriter{
ResponseWriter: c.Writer,
metricsRecorder: &MetricsRecorder{
metricsMonitor: pm.metricsMonitor,
realModelName: realModelName,
startTime: time.Now(),
},
}
c.Writer = writer
c.Next()
// check for streaming response
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
writer.metricsRecorder.processStreamingResponse(writer.body)
} else {
writer.metricsRecorder.processNonStreamingResponse(writer.body)
}
}
}
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
usage := jsonData.Get("usage")
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return false
}
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(rec.startTime).Milliseconds())
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
rec.metricsMonitor.addMetrics(TokenMetrics{
Timestamp: time.Now(),
Model: rec.realModelName,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
})
return true
}
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
// Iterate **backwards** through the lines looking for the data payload with
// usage data
lines := bytes.Split(body, []byte("\n"))
for i := len(lines) - 1; i >= 0; i-- {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
return // short circuit if a metric was recorded
}
}
}
}
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
if len(body) == 0 {
return
}
// Parse JSON to extract usage information
if gjson.ValidBytes(body) {
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
}
}
// MetricsResponseWriter captures the entire response for non-streaming
type MetricsResponseWriter struct {
gin.ResponseWriter
body []byte
metricsRecorder *MetricsRecorder
}
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
if err != nil {
return n, err
}
w.body = append(w.body, b...)
return n, nil
}
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *MetricsResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
+527 -18
View File
@@ -1,13 +1,66 @@
package proxy
import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/mostlygeek/llama-swap/event"
"github.com/tidwall/gjson"
)
// zstdEncOptions are the shared zstd encoder options for maximum compression.
var zstdEncOptions = []zstd.EOption{
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
}
// zstdDecOptions are the shared zstd decoder options.
var zstdDecOptions = []zstd.DOption{}
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
var zstdEncPool = &sync.Pool{
New: func() interface{} {
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
return enc
},
}
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
var zstdDecPool = &sync.Pool{
New: func() interface{} {
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
return dec
},
}
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
// Returns compressed bytes and the original JSON byte count for logging.
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
jsonBytes, err := json.Marshal(c)
if err != nil {
return nil, 0, fmt.Errorf("marshal capture: %w", err)
}
enc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(enc)
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
}
// decompressCapture decompresses zstd-compressed JSON and returns it.
func decompressCapture(data []byte) ([]byte, error) {
dec := zstdDecPool.Get().(*zstd.Decoder)
defer zstdDecPool.Put(dec)
return dec.DecodeAll(data, nil)
}
// TokenMetrics represents parsed token statistics from llama-server logs
type TokenMetrics struct {
ID int `json:"id"`
@@ -19,6 +72,16 @@ type TokenMetrics struct {
PromptPerSecond float64 `json:"prompt_per_second"`
TokensPerSecond float64 `json:"tokens_per_second"`
DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"`
}
type ReqRespCapture struct {
ID int `json:"id"`
ReqPath string `json:"req_path"`
ReqHeaders map[string]string `json:"req_headers"`
ReqBody []byte `json:"req_body"`
RespHeaders map[string]string `json:"resp_headers"`
RespBody []byte `json:"resp_body"`
}
// TokenMetricsEvent represents a token metrics event
@@ -30,29 +93,39 @@ func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
}
// MetricsMonitor parses llama-server output for token statistics
type MetricsMonitor struct {
// metricsMonitor parses llama-server output for token statistics
type metricsMonitor struct {
mu sync.RWMutex
metrics []TokenMetrics
maxMetrics int
nextID int
logger *LogMonitor
// capture fields
enableCaptures bool
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
captureOrder []int // track insertion order for FIFO eviction
captureSize int // current total compressed size in bytes
maxCaptureSize int // max bytes for captures (uncompressed)
}
func NewMetricsMonitor(config *Config) *MetricsMonitor {
maxMetrics := config.MetricsMaxInMemory
if maxMetrics <= 0 {
maxMetrics = 1000 // Default fallback
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
// capture buffer size in megabytes; 0 disables captures.
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
return &metricsMonitor{
logger: logger,
maxMetrics: maxMetrics,
enableCaptures: captureBufferMB > 0,
captures: make(map[int][]byte),
captureOrder: make([]int, 0),
captureSize: 0,
maxCaptureSize: captureBufferMB * 1024 * 1024,
}
mp := &MetricsMonitor{
maxMetrics: maxMetrics,
}
return mp
}
// addMetrics adds a new metric to the collection and publishes an event
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
// addMetrics adds a new metric to the collection and publishes an event.
// Returns the assigned metric ID.
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
mp.mu.Lock()
defer mp.mu.Unlock()
@@ -63,10 +136,88 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
}
event.Emit(TokenMetricsEvent{Metrics: metric})
return metric.ID
}
// GetMetrics returns a copy of the current metrics
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
// addCapture adds a new capture to the buffer with size-based eviction.
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
if !mp.enableCaptures {
return
}
compressed, uncompressedBytes, err := compressCapture(&capture)
if err != nil {
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
return
}
captureSize := len(compressed)
if captureSize > mp.maxCaptureSize {
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
return
}
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
mp.mu.Lock()
defer mp.mu.Unlock()
// Evict oldest (FIFO) until room available for the compressed data
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
oldestID := mp.captureOrder[0]
mp.captureOrder = mp.captureOrder[1:]
if evicted, exists := mp.captures[oldestID]; exists {
l := len(evicted)
mp.captureSize -= l
delete(mp.captures, oldestID)
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
}
}
mp.captures[capture.ID] = compressed
mp.captureOrder = append(mp.captureOrder, capture.ID)
mp.captureSize += captureSize
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
}
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
mp.mu.RLock()
defer mp.mu.RUnlock()
data, exists := mp.captures[id]
return data, exists
}
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
// If decompress=false, returns the raw zstd-compressed bytes.
// Returns nil if the capture is not found.
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
mp.mu.RLock()
defer mp.mu.RUnlock()
data, exists := mp.captures[id]
if !exists {
return nil
}
if !decompress {
return data
}
decompressed, err := decompressCapture(data)
if err != nil {
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
return nil
}
return decompressed
}
// getMetrics returns a copy of the current metrics
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
mp.mu.RLock()
defer mp.mu.RUnlock()
@@ -75,9 +226,367 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
return result
}
// GetMetricsJSON returns metrics as JSON
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
// getMetricsJSON returns metrics as JSON
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
}
// wrapHandler wraps the proxy handler to extract token metrics
// if wrapHandler returns an error it is safe to assume that no
// data was sent to the client
func (mp *metricsMonitor) wrapHandler(
modelID string,
writer gin.ResponseWriter,
request *http.Request,
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error {
// Capture request body and headers if captures enabled
var reqBody []byte
var reqHeaders map[string]string
if mp.enableCaptures {
if request.Body != nil {
var err error
reqBody, err = io.ReadAll(request.Body)
if err != nil {
return fmt.Errorf("failed to read request body for capture: %w", err)
}
request.Body.Close()
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
}
reqHeaders = make(map[string]string)
for key, values := range request.Header {
if len(values) > 0 {
reqHeaders[key] = values[0]
}
}
redactHeaders(reqHeaders)
}
recorder := newBodyCopier(writer)
// Filter Accept-Encoding to only include encodings we can decompress for metrics
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
}
if err := next(modelID, recorder, request); err != nil {
return err
}
// after this point we have to assume that data was sent to the client
// and we can only log errors but not send them to clients
if recorder.Status() != http.StatusOK {
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
return nil
}
// Initialize default metrics - these will always be recorded
tm := TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
}
body := recorder.body.Bytes()
if len(body) == 0 {
mp.logger.Warn("metrics: empty body, recording minimal metrics")
mp.addMetrics(tm)
return nil
}
// Decompress if needed
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
var err error
body, err = decompressBody(body, encoding)
if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
mp.addMetrics(tm)
return nil
}
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsed
}
} else {
if gjson.ValidBytes(body) {
parsed := gjson.ParseBytes(body)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
// extract timings for infill - response is an array, timings are in the last element
// see #463
if strings.HasPrefix(request.URL.Path, "/infill") {
if arr := parsed.Array(); len(arr) > 0 {
timings = arr[len(arr)-1].Get("timings")
}
}
if usage.Exists() || timings.Exists() {
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsedMetrics
}
}
} else {
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
}
}
// Build capture if enabled and determine if it will be stored
var capture *ReqRespCapture
if mp.enableCaptures {
respHeaders := make(map[string]string)
for key, values := range recorder.Header() {
if len(values) > 0 {
respHeaders[key] = values[0]
}
}
redactHeaders(respHeaders)
delete(respHeaders, "Content-Encoding")
capture = &ReqRespCapture{
ReqPath: request.URL.Path,
ReqHeaders: reqHeaders,
ReqBody: reqBody,
RespHeaders: respHeaders,
RespBody: body,
}
compressed, _, err := compressCapture(capture)
if err == nil && len(compressed) <= mp.maxCaptureSize {
tm.HasCapture = true
}
}
metricID := mp.addMetrics(tm)
// Store capture if enabled
if capture != nil {
capture.ID = metricID
mp.addCapture(*capture)
}
return nil
}
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
// Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split.
// Start from the end of the body and scan backwards for newlines
pos := len(body)
for pos > 0 {
// Find the previous newline (or start of body)
lineStart := bytes.LastIndexByte(body[:pos], '\n')
if lineStart == -1 {
lineStart = 0
} else {
lineStart++ // Move past the newline
}
line := bytes.TrimSpace(body[lineStart:pos])
pos = lineStart - 1 // Move position before the newline for next iteration
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
parsed := gjson.ParseBytes(data)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
// v1/responses format nests usage under response.usage
if !usage.Exists() {
usage = parsed.Get("response.usage")
}
if usage.Exists() || timings.Exists() {
return parseMetrics(modelID, start, usage, timings)
}
}
}
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
}
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
wallDurationMs := int(time.Since(start).Milliseconds())
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := wallDurationMs
if usage.Exists() {
if pt := usage.Get("prompt_tokens"); pt.Exists() {
// v1/chat/completions
inputTokens = int(pt.Int())
} else if it := usage.Get("input_tokens"); it.Exists() {
// v1/messages
inputTokens = int(it.Int())
}
if ct := usage.Get("completion_tokens"); ct.Exists() {
// v1/chat/completions
outputTokens = int(ct.Int())
} else if ot := usage.Get("output_tokens"); ot.Exists() {
outputTokens = int(ot.Int())
}
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
cachedTokens = int(ct.Int())
}
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(timings.Get("prompt_n").Int())
outputTokens = int(timings.Get("predicted_n").Int())
promptPerSecond = timings.Get("prompt_per_second").Float()
tokensPerSecond = timings.Get("predicted_per_second").Float()
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
if timingsDurationMs > durationMs {
durationMs = timingsDurationMs
}
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
return TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
}, nil
}
// decompressBody decompresses the body based on Content-Encoding header
func decompressBody(body []byte, encoding string) ([]byte, error) {
switch strings.ToLower(strings.TrimSpace(encoding)) {
case "gzip":
reader, err := gzip.NewReader(bytes.NewReader(body))
if err != nil {
return nil, err
}
defer reader.Close()
return io.ReadAll(reader)
case "deflate":
reader := flate.NewReader(bytes.NewReader(body))
defer reader.Close()
return io.ReadAll(reader)
default:
return body, nil // Return as-is for unknown/no encoding
}
}
// responseBodyCopier records the response body and writes to the original response writer
// while also capturing it in a buffer for later processing
type responseBodyCopier struct {
gin.ResponseWriter
body *bytes.Buffer
tee io.Writer
start time.Time
}
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
bodyBuffer := &bytes.Buffer{}
return &responseBodyCopier{
ResponseWriter: w,
body: bodyBuffer,
tee: io.MultiWriter(w, bodyBuffer),
}
}
func (w *responseBodyCopier) Write(b []byte) (int, error) {
if w.start.IsZero() {
w.start = time.Now()
}
// Single write operation that writes to both the response and buffer
return w.tee.Write(b)
}
func (w *responseBodyCopier) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *responseBodyCopier) Header() http.Header {
return w.ResponseWriter.Header()
}
func (w *responseBodyCopier) StartTime() time.Time {
return w.start
}
// sensitiveHeaders lists headers that should be redacted in captures
var sensitiveHeaders = map[string]bool{
"authorization": true,
"proxy-authorization": true,
"cookie": true,
"set-cookie": true,
"x-api-key": true,
}
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
func redactHeaders(headers map[string]string) {
for key := range headers {
if sensitiveHeaders[strings.ToLower(key)] {
headers[key] = "[REDACTED]"
}
}
}
// filterAcceptEncoding filters the Accept-Encoding header to only include
// encodings we can decompress (gzip, deflate). This respects the client's
// preferences while ensuring we can parse response bodies for metrics.
func filterAcceptEncoding(acceptEncoding string) string {
if acceptEncoding == "" {
return ""
}
supported := map[string]bool{"gzip": true, "deflate": true}
var filtered []string
for part := range strings.SplitSeq(acceptEncoding, ",") {
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
if supported[strings.ToLower(encoding)] {
filtered = append(filtered, strings.TrimSpace(part))
}
}
return strings.Join(filtered, ", ")
}
File diff suppressed because it is too large Load Diff
+143
View File
@@ -0,0 +1,143 @@
package proxy
import (
"fmt"
"net"
"net/http"
"net/http/httputil"
"runtime"
"sort"
"strings"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
)
type peerProxyMember struct {
peerID string
reverseProxy *httputil.ReverseProxy
apiKey string
}
type PeerProxy struct {
peers config.PeerDictionaryConfig
proxyMap map[string]*peerProxyMember
}
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
proxyMap := make(map[string]*peerProxyMember)
// Sort peer IDs for consistent iteration order
peerIDs := make([]string, 0, len(peers))
for peerID := range peers {
peerIDs = append(peerIDs, peerID)
}
sort.Strings(peerIDs)
for _, peerID := range peerIDs {
peer := peers[peerID]
// Create a transport with per-peer timeout configuration
peerTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
}
// Create reverse proxy for this peer
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
reverseProxy.Transport = peerTransport
// Wrap Director to set Host header for remote hosts (not localhost)
originalDirector := reverseProxy.Director
reverseProxy.Director = func(req *http.Request) {
originalDirector(req)
// Ensure Host header matches target URL for remote proxying
req.Host = req.URL.Host
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
errMsg := fmt.Sprintf("peer proxy error: %v", err)
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
}
http.Error(w, errMsg, http.StatusBadGateway)
}
pp := &peerProxyMember{
peerID: peerID,
reverseProxy: reverseProxy,
apiKey: peer.ApiKey,
}
// Map each model to this peer's proxy
for _, modelID := range peer.Models {
if _, found := proxyMap[modelID]; found {
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
continue
}
proxyMap[modelID] = pp
}
}
return &PeerProxy{
peers: peers,
proxyMap: proxyMap,
}, nil
}
func (p *PeerProxy) HasPeerModel(modelID string) bool {
_, found := p.proxyMap[modelID]
return found
}
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
pp, found := p.proxyMap[modelID]
if !found {
return config.Filters{}
}
// Get the peer config using the peerID
peer, found := p.peers[pp.peerID]
if !found {
return config.Filters{}
}
return peer.Filters
}
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
return p.peers
}
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
pp, found := p.proxyMap[model_id]
if !found {
return fmt.Errorf("no peer proxy found for model %s", model_id)
}
// Inject API key if configured for this peer
if pp.apiKey != "" {
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
request.Header.Set("x-api-key", pp.apiKey)
}
pp.reverseProxy.ServeHTTP(writer, request)
return nil
}
+311
View File
@@ -0,0 +1,311 @@
package proxy
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
peers := config.PeerDictionaryConfig{}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.NotNil(t, pm)
assert.Empty(t, pm.proxyMap)
}
func TestNewPeerProxy_SinglePeer(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
ApiKey: "test-key",
Models: []string{"model-a", "model-b"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.Len(t, pm.proxyMap, 2)
assert.True(t, pm.HasPeerModel("model-a"))
assert.True(t, pm.HasPeerModel("model-b"))
assert.False(t, pm.HasPeerModel("model-c"))
}
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"model-a", "model-b"},
},
"peer2": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"model-c", "model-d"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.Len(t, pm.proxyMap, 4)
assert.True(t, pm.HasPeerModel("model-a"))
assert.True(t, pm.HasPeerModel("model-b"))
assert.True(t, pm.HasPeerModel("model-c"))
assert.True(t, pm.HasPeerModel("model-d"))
}
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
// should be mapped, and a warning should be logged
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"alpha-peer": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"duplicate-model"},
},
"beta-peer": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"duplicate-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
// Should only have one entry for the duplicate model
assert.Len(t, pm.proxyMap, 1)
assert.True(t, pm.HasPeerModel("duplicate-model"))
}
func TestHasPeerModel(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
Models: []string{"existing-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.True(t, pm.HasPeerModel("existing-model"))
assert.False(t, pm.HasPeerModel("non-existing-model"))
}
func TestProxyRequest_ModelNotFound(t *testing.T) {
peers := config.PeerDictionaryConfig{}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("non-existing-model", w, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
}
func TestProxyRequest_Success(t *testing.T) {
// Create a test server to act as the peer
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("response from peer"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "response from peer", w.Body.String())
}
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
// Create a test server that checks for the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "secret-api-key",
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
}
func TestProxyRequest_NoApiKey(t *testing.T) {
// Create a test server that checks for the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "", // No API key
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Empty(t, receivedAuthHeader)
}
func TestProxyRequest_HostHeaderSet(t *testing.T) {
// Create a test server that checks the Host header
var receivedHost string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHost = r.Host
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
// The Host header should be set to the target URL's host
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
}
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
// Create a test server that returns SSE content type
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
// The X-Accel-Buffering header should be set to "no" for SSE
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
}
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
proxyURL, _ := url.Parse("http://localhost:8080")
peers := config.PeerDictionaryConfig{
"test-peer": config.PeerConfig{
Proxy: "http://localhost:8080",
ProxyURL: proxyURL,
Models: []string{"model1"},
Timeouts: config.TimeoutsConfig{
Connect: 45,
ResponseHeader: 300,
TLSHandshake: 15,
ExpectContinue: 2,
IdleConn: 120,
},
},
}
peerProxy, err := NewPeerProxy(peers, testLogger)
assert.NoError(t, err)
assert.NotNil(t, peerProxy)
assert.True(t, peerProxy.HasPeerModel("model1"))
// Verify the timeout values are actually applied to the transport
member, found := peerProxy.proxyMap["model1"]
require.True(t, found, "model1 should exist in proxyMap")
assert.NotNil(t, member.reverseProxy)
assert.NotNil(t, member.reverseProxy.Transport)
transport, ok := member.reverseProxy.Transport.(*http.Transport)
require.True(t, ok, "Transport should be *http.Transport")
// Verify all timeout values are correctly applied
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
// ForceAttemptHTTP2 should be enabled
assert.True(t, transport.ForceAttemptHTTP2)
}
+460 -72
View File
@@ -2,20 +2,23 @@ package proxy
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os/exec"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
)
type ProcessState string
@@ -38,11 +41,13 @@ const (
)
type Process struct {
ID string
config ModelConfig
cmd *exec.Cmd
ID string
config config.ModelConfig
cmd *exec.Cmd
reverseProxy *httputil.ReverseProxy
// PR #155 called to cancel the upstream process
cmdMutex sync.RWMutex
cancelUpstream context.CancelFunc
// closed when command exits
@@ -54,12 +59,14 @@ type Process struct {
healthCheckTimeout int
healthCheckLoopInterval time.Duration
lastRequestHandled time.Time
lastRequestHandledMutex sync.RWMutex
lastRequestHandled time.Time
stateMutex sync.RWMutex
state ProcessState
inFlightRequests sync.WaitGroup
inFlightRequests sync.WaitGroup
inFlightRequestsCount atomic.Int32
// used to block on multiple start() calls
waitStarting sync.WaitGroup
@@ -70,20 +77,60 @@ type Process struct {
// used for testing to override the default value
gracefulStopTimeout time.Duration
// used for testing to bypass subprocess and reverse proxy
testHandler http.Handler
// track the number of failed starts
failedStartCount int
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
concurrentLimit := 10
if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit
}
// Setup the reverse proxy.
proxyURL, err := url.Parse(config.Proxy)
if err != nil {
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
}
var reverseProxy *httputil.ReverseProxy
if proxyURL != nil {
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
// Create custom transport with configured timeouts
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second,
}
reverseProxy.Transport = transport
reverseProxy.ModifyResponse = func(resp *http.Response) error {
// prevent nginx from buffering streaming responses (e.g., SSE)
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
}
return &Process{
ID: ID,
config: config,
cmd: nil,
reverseProxy: reverseProxy,
cancelUpstream: nil,
processLogger: processLogger,
proxyLogger: proxyLogger,
@@ -106,6 +153,20 @@ func (p *Process) LogMonitor() *LogMonitor {
return p.processLogger
}
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
func (p *Process) setLastRequestHandled(t time.Time) {
p.lastRequestHandledMutex.Lock()
defer p.lastRequestHandledMutex.Unlock()
p.lastRequestHandled = t
}
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
func (p *Process) getLastRequestHandled() time.Time {
p.lastRequestHandledMutex.RLock()
defer p.lastRequestHandledMutex.RUnlock()
return p.lastRequestHandled
}
// custom error types for swapping state
var (
ErrExpectedStateMismatch = errors.New("expected state mismatch")
@@ -129,6 +190,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
}
p.state = newState
// Atomically increment waitStarting when entering StateStarting
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
if newState == StateStarting {
p.waitStarting.Add(1)
}
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
return p.state, nil
@@ -157,11 +225,63 @@ func (p *Process) CurrentState() ProcessState {
return p.state
}
// forceState forces the process state to the new state with mutex protection.
// This should only be used in exceptional cases where the normal state transition
// validation via swapState() cannot be used.
func (p *Process) forceState(newState ProcessState) {
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.state = newState
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
func (p *Process) start() error {
// test-only fast path: skip subprocess, health check, and TTL goroutine
if p.testHandler != nil {
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
if curState == StateStarting {
p.waitStarting.Wait()
curState = p.CurrentState()
if curState == StateReady {
return nil
}
return fmt.Errorf("process was already starting but wound up in state %v", curState)
}
return fmt.Errorf("process was in state %v when start() was called", curState)
}
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
defer p.waitStarting.Done()
// Mimic the real stop path: cancelUpstream transitions
// StateStopping -> StateStopped and closes cmdWaitChan,
// matching what waitForCmd does for real subprocesses.
ch := make(chan struct{})
p.cmdMutex.Lock()
p.cancelUpstream = func() {
if curState := p.CurrentState(); curState == StateStopping {
if _, err := p.swapState(StateStopping, StateStopped); err != nil {
p.forceState(StateStopped)
}
} else {
p.forceState(StateStopped)
}
close(ch)
}
p.cmdWaitChan = ch
p.cmdMutex.Unlock()
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
}
p.failedStartCount = 0
return nil
}
if p.config.Proxy == "" {
return fmt.Errorf("can not start(), upstream proxy missing")
}
@@ -190,7 +310,7 @@ func (p *Process) start() error {
}
}
p.waitStarting.Add(1)
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
defer p.waitStarting.Done()
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
@@ -200,8 +320,12 @@ func (p *Process) start() error {
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
p.cmd.Cancel = p.cmdStopUpstreamProcess
p.cmd.WaitDelay = p.gracefulStopTimeout
setProcAttributes(p.cmd)
p.cmdMutex.Lock()
p.cancelUpstream = ctxCancelUpstream
p.cmdWaitChan = make(chan struct{})
p.cmdMutex.Unlock()
p.failedStartCount++ // this will be reset to zero when the process has successfully started
@@ -211,7 +335,7 @@ func (p *Process) start() error {
// Set process state to failed
if err != nil {
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
p.state = StateStopped // force it into a stopped state
p.forceState(StateStopped) // force it into a stopped state
return fmt.Errorf(
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
strings.Join(args, " "), err, curState, swapErr,
@@ -284,10 +408,12 @@ func (p *Process) start() error {
return
}
// wait for all inflight requests to complete and ticker
p.inFlightRequests.Wait()
// skip the TTL check if there are inflight requests
if p.inFlightRequestsCount.Load() != 0 {
continue
}
if time.Since(p.lastRequestHandled) > maxDuration {
if time.Since(p.getLastRequestHandled()) > maxDuration {
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
p.Stop()
return
@@ -306,7 +432,10 @@ func (p *Process) start() error {
// Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() {
// guard to prevent multiple goroutines from stopping
if !isValidTransition(p.CurrentState(), StateStopping) {
p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
return
}
@@ -319,13 +448,17 @@ func (p *Process) Stop() {
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
// guard to prevent multiple goroutines from stopping the process
enterState := p.CurrentState()
if !isValidTransition(enterState, StateStopping) {
p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
return
}
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState)
if curState, err := p.swapState(enterState, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState)
return
}
@@ -343,7 +476,7 @@ func (p *Process) Shutdown() {
p.stopCommand()
// just force it to this state since there is no recovery from shutdown
p.state = StateShutdown
p.forceState(StateShutdown)
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -352,15 +485,23 @@ func (p *Process) stopCommand() {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
// free the buffer in processLogger so the memory can be recovered
p.processLogger.Clear()
}()
if p.cancelUpstream == nil {
p.cmdMutex.RLock()
cancelUpstream := p.cancelUpstream
cmdWaitChan := p.cmdWaitChan
p.cmdMutex.RUnlock()
if cancelUpstream == nil {
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
return
}
p.cancelUpstream()
<-p.cmdWaitChan
cancelUpstream()
<-cmdWaitChan
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
@@ -398,6 +539,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if p.reverseProxy == nil {
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
return
}
requestBeginTime := time.Now()
var startDuration time.Duration
@@ -417,74 +564,82 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
}
p.inFlightRequests.Add(1)
p.inFlightRequestsCount.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.setLastRequestHandled(time.Now())
p.inFlightRequestsCount.Add(-1)
p.inFlightRequests.Done()
}()
// for #366
// - extract streaming param from request context, should have been set by proxymanager
var srw *statusResponseWriter
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
// start the process on demand
if p.CurrentState() != StateReady {
// start a goroutine to stream loading status messages into the response writer
// add a sync so the streaming client only runs when the goroutine has exited
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
// PR #417 (no support for anthropic v1/messages yet)
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
srw = newStatusResponseWriter(p, w)
go srw.statusUpdates(swapCtx)
} else {
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
}
beginStartTime := time.Now()
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusBadGateway)
cancelLoadCtx()
if srw != nil {
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
// before closing the connection. Without this, the connection would close before
// the goroutine can write its cleanup messages, causing incomplete SSE output.
srw.waitForCompletion(100 * time.Millisecond)
} else {
http.Error(w, errstr, http.StatusBadGateway)
}
return
}
startDuration = time.Since(beginStartTime)
}
proxyTo := p.config.Proxy
client := &http.Client{}
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header.Clone()
// should trigger srw to stop sending loading events ...
cancelLoadCtx()
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
if err == nil {
req.ContentLength = contentLength
}
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
// prevent nginx from buffering streaming responses (e.g., SSE)
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
w.Header().Set("X-Accel-Buffering", "no")
}
w.WriteHeader(resp.StatusCode)
// faster than io.Copy when streaming
buf := make([]byte, 32*1024)
for {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
return
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
// recover from http.ErrAbortHandler panics that can occur when the client
// disconnects before the response is sent
defer func() {
if r := recover(); r != nil {
if r == http.ErrAbortHandler {
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
} else {
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
}
}
if err == io.EOF {
break
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}()
if srw != nil {
// Wait for the goroutine to finish writing its final messages
const completionTimeout = 1 * time.Second
if !srw.waitForCompletion(completionTimeout) {
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
}
}
if p.testHandler != nil {
p.testHandler.ServeHTTP(w, r)
} else if srw != nil {
p.reverseProxy.ServeHTTP(srw, r)
} else {
p.reverseProxy.ServeHTTP(w, r)
}
totalTime := time.Since(requestBeginTime)
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
p.ID, r.RequestURI, startDuration, totalTime)
@@ -518,13 +673,16 @@ func (p *Process) waitForCmd() {
case StateStopping:
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
p.state = StateStopped
p.forceState(StateStopped)
}
default:
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
p.state = StateStopped // force it to be in this state
p.forceState(StateStopped) // force it to be in this state
}
p.cmdMutex.Lock()
close(p.cmdWaitChan)
p.cmdMutex.Unlock()
}
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
@@ -539,7 +697,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
if err != nil {
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
return err
@@ -550,6 +708,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Stdout = p.processLogger
stopCmd.Stderr = p.processLogger
setProcAttributes(stopCmd)
stopCmd.Env = p.cmd.Env
if err := stopCmd.Run(); err != nil {
@@ -565,3 +724,232 @@ func (p *Process) cmdStopUpstreamProcess() error {
return nil
}
// Logger returns the logger for this process.
func (p *Process) Logger() *LogMonitor {
return p.processLogger
}
var loadingRemarks = []string{
"Still faster than your last standup meeting...",
"Reticulating splines...",
"Waking up the hamsters...",
"Teaching the model manners...",
"Convincing the GPU to participate...",
"Loading weights (they're heavy)...",
"Herding electrons...",
"Compiling excuses for the delay...",
"Downloading more RAM...",
"Asking the model nicely to boot up...",
"Bribing CUDA with cookies...",
"Still loading (blame VRAM)...",
"The model is fashionably late...",
"Warming up those tensors...",
"Making the neural net do push-ups...",
"Your patience is appreciated (really)...",
"Almost there (probably)...",
"Loading like it's 1999...",
"The model forgot where it put its keys...",
"Quantum tunneling through layers...",
"Negotiating with the PCIe bus...",
"Defrosting frozen parameters...",
"Teaching attention heads to focus...",
"Running the matrix (slowly)...",
"Untangling transformer blocks...",
"Calibrating the flux capacitor...",
"Spinning up the probability wheels...",
"Waiting for the GPU to wake from its nap...",
"Converting caffeine to compute...",
"Allocating virtual patience...",
"Performing arcane CUDA rituals...",
"The model is stuck in traffic...",
"Inflating embeddings...",
"Summoning computational demons...",
"Pleading with the OOM killer...",
"Calculating the meaning of life (still at 42)...",
"Training the training wheels...",
"Optimizing the optimizer...",
"Bootstrapping the bootstrapper...",
"Loading loading screen...",
"Processing processing logs...",
"Buffering buffer overflow jokes...",
"The model hit snooze...",
"Debugging the debugger...",
"Compiling the compiler...",
"Parsing the parser (meta)...",
"Tokenizing tokens...",
"Encoding the encoder...",
"Hashing hash browns...",
"Forking spoons (not forks)...",
"The model is contemplating existence...",
"Transcending dimensional barriers...",
"Invoking elder tensor gods...",
"Unfurling probability clouds...",
"Synchronizing parallel universes...",
"The GPU is having second thoughts...",
"Recalibrating reality matrices...",
"Time is an illusion, loading doubly so...",
"Convincing bits to flip themselves...",
"The model is reading its own documentation...",
}
type statusResponseWriter struct {
hasWritten bool
writer http.ResponseWriter
process *Process
wg sync.WaitGroup // Track goroutine completion
start time.Time
}
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
s := &statusResponseWriter{
writer: w,
process: p,
start: time.Now(),
}
s.Header().Set("Content-Type", "text/event-stream") // SSE
s.Header().Set("Cache-Control", "no-cache") // no-cache
s.Header().Set("Connection", "keep-alive") // keep-alive
s.WriteHeader(http.StatusOK) // send status code 200
s.sendLine("━━━━━")
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
return s
}
// statusUpdates sends status updates to the client while the model is loading
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
s.wg.Add(1)
defer s.wg.Done()
// Recover from panics caused by client disconnection
// Note: recover() only works within the same goroutine, so we need it here
defer func() {
if r := recover(); r != nil {
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
}
}()
defer func() {
duration := time.Since(s.start)
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
s.sendLine("━━━━━")
s.sendLine(" ")
}()
// Create a shuffled copy of loadingRemarks
remarks := make([]string, len(loadingRemarks))
copy(remarks, loadingRemarks)
rand.Shuffle(len(remarks), func(i, j int) {
remarks[i], remarks[j] = remarks[j], remarks[i]
})
ri := 0
// Pick a random duration to send a remark
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
lastRemarkTime := time.Now()
ticker := time.NewTicker(time.Second)
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if s.process.CurrentState() == StateReady {
return
}
// Check if it's time for a snarky remark
if time.Since(lastRemarkTime) >= nextRemarkIn {
remark := remarks[ri%len(remarks)]
ri++
s.sendLine(fmt.Sprintf("\n%s", remark))
lastRemarkTime = time.Now()
// Pick a new random duration for the next remark
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
} else {
s.sendData(".")
}
}
}
}
// waitForCompletion waits for the statusUpdates goroutine to finish
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
return true
case <-time.After(timeout):
return false
}
}
func (s *statusResponseWriter) sendLine(line string) {
s.sendData(line + "\n")
}
func (s *statusResponseWriter) sendData(data string) {
// Create the proper SSE JSON structure
type Delta struct {
ReasoningContent string `json:"reasoning_content"`
}
type Choice struct {
Delta Delta `json:"delta"`
}
type SSEMessage struct {
Choices []Choice `json:"choices"`
}
msg := SSEMessage{
Choices: []Choice{
{
Delta: Delta{
ReasoningContent: data,
},
},
},
}
jsonData, err := json.Marshal(msg)
if err != nil {
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
return
}
// Write SSE formatted data, panic if not able to write
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
if err != nil {
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
}
s.Flush()
}
func (s *statusResponseWriter) Header() http.Header {
return s.writer.Header()
}
func (s *statusResponseWriter) Write(data []byte) (int, error) {
return s.writer.Write(data)
}
func (s *statusResponseWriter) WriteHeader(statusCode int) {
if s.hasWritten {
return
}
s.hasWritten = true
s.writer.WriteHeader(statusCode)
s.Flush()
}
func (s *statusResponseWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
}
}
+137 -22
View File
@@ -2,6 +2,7 @@ package proxy
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
@@ -10,6 +11,7 @@ import (
"testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
)
@@ -90,7 +92,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
config := ModelConfig{
config := config.ModelConfig{
Cmd: "nonexistent-command",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
@@ -116,12 +118,12 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
}
expectedMessage := "I_sense_imminent_danger"
config := getTestSimpleResponderConfig(expectedMessage)
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
conf := getTestSimpleResponderConfig(expectedMessage)
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
conf.UnloadAfter = 3 // seconds
assert.Equal(t, 3, conf.UnloadAfter)
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
defer process.Stop()
// this should take 4 seconds
@@ -158,12 +160,12 @@ func TestProcess_LowTTLValue(t *testing.T) {
t.Skip("skipping test, edit process_test.go to run it ")
}
config := getTestSimpleResponderConfig("fast_ttl")
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
conf := getTestSimpleResponderConfig("fast_ttl")
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
conf.UnloadAfter = 1 // second
assert.Equal(t, 1, conf.UnloadAfter)
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
defer process.Stop()
for i := 0; i < 100; i++ {
@@ -325,7 +327,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
config := config.ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
@@ -394,6 +396,10 @@ func TestProcess_StopImmediately(t *testing.T) {
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
// the upstream command
func TestProcess_ForceStopWithKill(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
if runtime.GOOS == "windows" {
t.Skip("skipping SIGTERM test on Windows ")
}
@@ -402,7 +408,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
binaryPath := getSimpleResponderPath()
port := getTestPort()
config := ModelConfig{
conf := config.ModelConfig{
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
// to force the process to exit
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
@@ -410,7 +416,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
CheckEndpoint: "/health",
}
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
defer process.Stop()
// reduce to make testing go faster
@@ -435,7 +441,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
if runtime.GOOS == "windows" {
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
} else {
assert.Contains(t, w.Body.String(), "unexpected EOF")
// Upstream may be killed mid-response.
// Assert an incomplete or partial response.
assert.NotEqual(t, "12345", w.Body.String())
}
close(waitChan)
@@ -450,15 +458,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
}
func TestProcess_StopCmd(t *testing.T) {
config := getTestSimpleResponderConfig("test_stop_cmd")
conf := getTestSimpleResponderConfig("test_stop_cmd")
if runtime.GOOS == "windows" {
config.CmdStop = "taskkill /f /t /pid ${PID}"
conf.CmdStop = "taskkill /f /t /pid ${PID}"
} else {
config.CmdStop = "kill -TERM ${PID}"
conf.CmdStop = "kill -TERM ${PID}"
}
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
@@ -470,15 +478,15 @@ func TestProcess_StopCmd(t *testing.T) {
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
expectedMessage := "test_env_not_emptied"
config := getTestSimpleResponderConfig(expectedMessage)
conf := getTestSimpleResponderConfig(expectedMessage)
// ensure that the the default config does not blank out the inherited environment
configWEnv := config
configWEnv := conf
// ensure the additiona variables are appended to the process' environment
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
process1.start()
@@ -491,3 +499,110 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
}
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
// handled appropriately.
//
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
// can't copy the body. This can be caused by a client disconnecting before the full
// response is sent from some reason.
//
// bug: https://github.com/mostlygeek/llama-swap/issues/362
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
defer func() {
if r := recover(); r != nil {
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
}
}()
expectedMessage := "panic_test"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
defer process.Stop()
// Start the process
err := process.start()
assert.Nil(t, err)
assert.Equal(t, StateReady, process.CurrentState())
// Create a custom ResponseWriter that simulates a client disconnect
// by panicking when Write is called after headers are sent
panicWriter := &panicOnWriteResponseWriter{
ResponseRecorder: httptest.NewRecorder(),
shouldPanic: true,
}
// Make a request that will trigger the panic
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
// ProxyRequest should catch and handle this panic gracefully.
process.ProxyRequest(panicWriter, req)
// If we get here, the panic was properly recovered in ProxyRequest
// The process should still be in a ready state
assert.Equal(t, StateReady, process.CurrentState())
}
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
// to simulate a client disconnect after headers are sent
// used by: TestProcess_ReverseProxyPanicIsHandled
type panicOnWriteResponseWriter struct {
*httptest.ResponseRecorder
shouldPanic bool
headerWritten bool
}
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
w.headerWritten = true
w.ResponseRecorder.WriteHeader(statusCode)
}
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
if w.shouldPanic && w.headerWritten {
// Simulate the panic that httputil.ReverseProxy throws
panic(http.ErrAbortHandler)
}
return w.ResponseRecorder.Write(b)
}
func TestProcess_CustomTimeouts(t *testing.T) {
modelConfig := config.ModelConfig{
Cmd: "echo test",
Proxy: "http://localhost:8080",
CheckEndpoint: "/health",
Timeouts: config.TimeoutsConfig{
Connect: 45,
ResponseHeader: 120,
TLSHandshake: 15,
ExpectContinue: 2,
IdleConn: 120,
},
}
debugLogger := NewLogMonitorWriter(io.Discard)
process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger)
// Verify the process was created successfully
assert.NotNil(t, process)
assert.Equal(t, "test-model", process.ID)
assert.NotNil(t, process.reverseProxy)
assert.NotNil(t, process.reverseProxy.Transport)
// Verify it's using http.Transport (not some other type)
transport, ok := process.reverseProxy.Transport.(*http.Transport)
assert.True(t, ok, "Transport should be *http.Transport")
assert.NotNil(t, transport)
// Verify the timeouts are correctly applied
assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout)
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
assert.True(t, transport.ForceAttemptHTTP2)
}
+12
View File
@@ -0,0 +1,12 @@
//go:build !windows
package proxy
import (
"os/exec"
)
// setProcAttributes sets platform-specific process attributes
func setProcAttributes(cmd *exec.Cmd) {
// No-op on Unix systems
}
+16
View File
@@ -0,0 +1,16 @@
//go:build windows
package proxy
import (
"os/exec"
"syscall"
)
// setProcAttributes sets platform-specific process attributes
func setProcAttributes(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
HideWindow: true,
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
}
}
+49 -3
View File
@@ -5,12 +5,14 @@ import (
"net/http"
"slices"
"sync"
"github.com/mostlygeek/llama-swap/proxy/config"
)
type ProcessGroup struct {
sync.Mutex
config Config
config config.Config
id string
swap bool
exclusive bool
@@ -22,9 +24,25 @@ type ProcessGroup struct {
// map of current processes
processes map[string]*Process
lastUsedProcess string
// inflight tracks fast-path requests (requests for the already-selected
// model in a swap group). Fast-path requests Add(1) while holding pg.Lock
// and Done() on completion; a concurrent swap request calls inflight.Wait()
// under pg.Lock before stopping the current process. Without this tracking,
// a fast-path request that has released pg.Lock but has not yet called
// Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be
// killed mid-request.
inflight sync.WaitGroup
// testDelayFastPath is a test-only hook that, when non-nil, is invoked in
// the fast path after pg.Lock is released but before the request is
// dispatched to Process.ProxyRequest. Tests use it to park a fast-path
// request at the exact race window to deterministically reproduce the
// fast-path vs swap race.
testDelayFastPath func()
}
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
@@ -44,7 +62,8 @@ func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstream
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
processLogger := NewLogMonitorWriter(upstreamLogger)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
@@ -61,6 +80,13 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
pg.Lock()
if pg.lastUsedProcess != modelID {
// Wait for in-flight fast-path requests to drain before stopping
// the previous process. Without this, a fast-path request that has
// released pg.Lock but has not yet incremented
// Process.inFlightRequests races with Stop() and can be killed
// mid-request.
pg.inflight.Wait()
// is there something already running?
if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop()
@@ -75,7 +101,16 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
pg.Unlock()
return nil
}
// Fast path: register this request in inflight before releasing
// pg.Lock so a concurrent swap will wait for it to complete.
pg.inflight.Add(1)
defer pg.inflight.Done()
pg.Unlock()
if pg.testDelayFastPath != nil {
pg.testDelayFastPath()
}
}
pg.processes[modelID].ProxyRequest(writer, request)
@@ -86,6 +121,13 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
if pg.HasMember(modelName) {
return pg.processes[modelName], true
}
return nil, false
}
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
pg.Lock()
@@ -113,6 +155,10 @@ func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
if strategy != StopImmediately {
pg.inflight.Wait()
}
if len(pg.processes) == 0 {
return
}
+238 -7
View File
@@ -4,22 +4,26 @@ import (
"bytes"
"net/http"
"net/http/httptest"
"runtime"
"sync"
"testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
@@ -34,7 +38,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
})
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}
@@ -48,9 +52,13 @@ func TestProcessGroup_HasMember(t *testing.T) {
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
// and multiple requests are made in parallel, only one process is running at a time.
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
if testing.Short() {
t.Skip("skipping slow test")
}
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
// use the same listening so if a model is already running, it will fail
// this is a way to test that swap isolation is working
// properly when there are parallel requests made at the
@@ -61,7 +69,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
"model4": getTestSimpleResponderConfigPort("model4", 9832),
"model5": getTestSimpleResponderConfigPort("model5", 9832),
},
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2", "model3", "model4", "model5"},
@@ -90,6 +98,229 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
wg.Wait()
}
// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap
// request cannot stop the current process while a fast-path request (for the
// already-selected model) is in flight. Without ProcessGroup-level inflight
// tracking, a fast-path request that has released pg.Lock but has not yet
// incremented Process.inFlightRequests races with Stop()'s Wait() and the
// process is killed mid-request.
func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
cfg := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2"},
},
},
})
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
defer pg.StopProcesses(StopImmediately)
// Bypass real subprocesses so the test is fast and deterministic.
pg.processes["model1"].testHandler = newTestHandler("model1")
pg.processes["model2"].testHandler = newTestHandler("model2")
// Prime: run a request through model1 via the swap path so that
// lastUsedProcess == "model1" and subsequent model1 requests take the
// fast path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
require.Equal(t, StateStopped, pg.processes["model2"].CurrentState())
// Fast-path hook: signal arrival at the race window, then wait for
// release. This parks R2 deterministically at the point where pg.Lock
// has been released but Process.inFlightRequests has not yet been
// incremented — the exact window the race exploits.
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
pg.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
// R2: fast-path request for model1. Will pause at the test hook.
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
}()
// Deterministically wait for R2 to reach the race window.
<-r2Reached
// R3: swap request for model2. Must wait for R2 to finish before touching
// model1, otherwise model1 gets killed mid-request.
r3Done := make(chan struct{})
w3 := httptest.NewRecorder()
go func() {
defer close(r3Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model2", w3, req))
}()
// Spin until R3 has acquired pg.Lock and entered the swap critical
// section. In the fixed code, R3 then blocks on pg.inflight.Wait() while
// still holding the lock, so TryLock keeps failing.
for pg.TryLock() {
pg.Unlock()
runtime.Gosched()
}
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
// state. In the fixed code, R3 is blocked on pg.inflight.Wait() and
// nothing changes, so we wait the full window. In the buggy code, R3
// will Stop() model1 and start serving via model2 within microseconds —
// we exit early once the mutation is observable.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if pg.processes["model1"].CurrentState() != StateReady ||
pg.processes["model2"].CurrentState() != StateStopped {
break
}
done := false
select {
case <-r3Done:
done = true
default:
}
if done {
break
}
runtime.Gosched()
}
// Invariant: R3 must be blocked while R2 is still in flight.
select {
case <-r3Done:
t.Fatal("swap completed while fast-path request was still in flight — race not prevented")
default:
}
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
"model1 must stay Ready while a fast-path request is in flight")
assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(),
"model2 must not be started until R2 finishes and model1 is swapped out")
// Release R2 and let both requests finish.
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
assert.Equal(t, http.StatusOK, w3.Code)
assert.Contains(t, w3.Body.String(), "model2")
}
// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses
// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a
// process while a fast-path ProxyRequest is in the [pg.Unlock,
// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in
// StopProcesses, the external caller bypasses the inflight guard and kills the
// process mid-request.
func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) {
cfg := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2"},
},
},
})
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
defer pg.StopProcesses(StopImmediately)
pg.processes["model1"].testHandler = newTestHandler("model1")
pg.processes["model2"].testHandler = newTestHandler("model2")
// Prime: model1 is active so subsequent model1 requests take the fast path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
// Park a fast-path request at the race window.
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
pg.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
}()
<-r2Reached
// Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping
// the group while a fast-path request is in flight.
r3Done := make(chan struct{})
go func() {
defer close(r3Done)
pg.StopProcesses(StopWaitForInflightRequest)
}()
// Spin until StopProcesses has acquired pg.Lock.
for pg.TryLock() {
pg.Unlock()
runtime.Gosched()
}
// Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait()
// and model1 stays Ready. In the buggy code it proceeds immediately and
// kills model1.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if pg.processes["model1"].CurrentState() != StateReady {
break
}
select {
case <-r3Done:
goto done
default:
}
runtime.Gosched()
}
done:
select {
case <-r3Done:
t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented")
default:
}
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
"model1 must stay Ready while a fast-path request is in flight")
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest)
+633 -157
View File
File diff suppressed because it is too large Load Diff
+120 -34
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"github.com/gin-gonic/gin"
@@ -13,21 +14,26 @@ import (
)
type Model struct {
Id string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
Id string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
PeerID string `json:"peerID"`
Aliases []string `json:"aliases,omitempty"`
}
func addApiHandlers(pm *ProxyManager) {
// Add API endpoints for React to consume
apiGroup := pm.ginEngine.Group("/api")
// Protected with API key authentication
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
{
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics)
apiGroup.GET("/version", pm.apiGetVersion)
apiGroup.GET("/captures/:id", pm.apiGetCapture)
}
}
@@ -49,27 +55,28 @@ func (pm *ProxyManager) getModelStatus() []Model {
// Iterate over sorted keys
for _, modelID := range modelIDs {
// Get process state
processGroup := pm.findGroupByModelName(modelID)
state := "unknown"
if processGroup != nil {
process := processGroup.processes[modelID]
if process != nil {
var stateStr string
switch process.CurrentState() {
case StateReady:
stateStr = "ready"
case StateStarting:
stateStr = "starting"
case StateStopping:
stateStr = "stopping"
case StateShutdown:
stateStr = "shutdown"
case StateStopped:
stateStr = "stopped"
default:
stateStr = "unknown"
}
state = stateStr
var process *Process
if pm.matrix != nil {
process, _ = pm.matrix.GetProcess(modelID)
} else {
processGroup := pm.findGroupByModelName(modelID)
if processGroup != nil {
process = processGroup.processes[modelID]
}
}
if process != nil {
switch process.CurrentState() {
case StateReady:
state = "ready"
case StateStarting:
state = "starting"
case StateStopping:
state = "stopping"
case StateShutdown:
state = "shutdown"
case StateStopped:
state = "stopped"
}
}
models = append(models, Model{
@@ -78,9 +85,22 @@ func (pm *ProxyManager) getModelStatus() []Model {
Description: pm.config.Models[modelID].Description,
State: state,
Unlisted: pm.config.Models[modelID].Unlisted,
Aliases: pm.config.Models[modelID].Aliases,
})
}
// Iterate over the peer models
if pm.peerProxy != nil {
for peerID, peer := range pm.peerProxy.ListPeers() {
for _, modelID := range peer.Models {
models = append(models, Model{
Id: modelID,
PeerID: peerID,
})
}
}
}
return models
}
@@ -90,6 +110,7 @@ const (
msgTypeModelStatus messageType = "modelStatus"
msgTypeLogData messageType = "logData"
msgTypeMetrics messageType = "metrics"
msgTypeInFlight messageType = "inflight"
)
type messageEnvelope struct {
@@ -149,6 +170,18 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
}
}
sendInFlight := func(total int) {
jsonData, err := json.Marshal(gin.H{"total": total})
if err == nil {
select {
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
case <-ctx.Done():
return
default:
}
}
}
/**
* Send updated models list
*/
@@ -176,11 +209,19 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
sendMetrics([]TokenMetrics{e.Metrics})
})()
/**
* Send in-flight request stats related to token stats "Waiting: N" count.
*/
defer event.On(func(e InFlightRequestsEvent) {
sendInFlight(e.Total)
})()
// send initial batch of data
sendLogData("proxy", pm.proxyLogger.GetHistory())
sendLogData("upstream", pm.upstreamLogger.GetHistory())
sendModels()
sendMetrics(pm.metricsMonitor.GetMetrics())
sendMetrics(pm.metricsMonitor.getMetrics())
sendInFlight(pm.inFlightCounter.Current())
for {
select {
@@ -198,7 +239,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
}
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
jsonData, err := pm.metricsMonitor.getMetricsJSON()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
return
@@ -214,16 +255,61 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
return
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
var stopErr error
if pm.matrix != nil {
stopErr = pm.matrix.StopProcess(realModelName, StopImmediately)
} else {
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
return
}
stopErr = processGroup.StopProcess(realModelName, StopImmediately)
}
if stopErr != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error()))
return
}
c.String(http.StatusOK, "OK")
}
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
c.JSON(http.StatusOK, map[string]string{
"version": pm.version,
"commit": pm.commit,
"build_date": pm.buildDate,
})
}
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
return
}
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
data, exists := pm.metricsMonitor.getCompressedBytes(id)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
return
}
c.Header("Vary", "Accept-Encoding")
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd")
if hasZstd {
c.Header("Content-Encoding", "zstd")
c.Data(http.StatusOK, "application/json", data)
} else {
c.String(http.StatusOK, "OK")
decompressed, err := decompressCapture(data)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
return
}
c.Data(http.StatusOK, "application/json", decompressed)
}
}
+20 -13
View File
@@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// prevent nginx from buffering streamed logs
c.Header("X-Accel-Buffering", "no")
logMonitorId := c.Param("logMonitorID")
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
@@ -83,18 +83,25 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
switch logMonitorId {
case "":
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return pm.muxLogger, nil
case "proxy":
return pm.proxyLogger, nil
case "upstream":
return pm.upstreamLogger, nil
default:
// search for a models specific logger using findModelInPath
// to handle model names with slashes (e.g., "author/model")
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
for _, group := range pm.processGroups {
if process, found := group.GetMember(name); found {
return process.Logger(), nil
}
}
}
return logger, nil
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
}
}
+1021 -332
View File
File diff suppressed because it is too large Load Diff
+81
View File
@@ -0,0 +1,81 @@
package proxy
import (
"net/http"
"strings"
)
// selectEncoding chooses the best encoding based on Accept-Encoding header
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
func selectEncoding(acceptEncoding string) (encoding, ext string) {
if acceptEncoding == "" {
return "", ""
}
for _, part := range strings.Split(acceptEncoding, ",") {
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
if enc == "br" {
return "br", ".br"
}
}
for _, part := range strings.Split(acceptEncoding, ",") {
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
if enc == "gzip" {
return "gzip", ".gz"
}
}
return "", ""
}
// ServeCompressedFile serves a file with compression support.
// It checks for pre-compressed versions and serves them with proper headers.
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
// Try to serve compressed version if client supports it
if encoding != "" {
if cf, err := fs.Open(name + ext); err == nil {
defer cf.Close()
// Verify it's a regular file (not a directory)
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
// Set the content encoding header
w.Header().Set("Content-Encoding", encoding)
w.Header().Add("Vary", "Accept-Encoding")
// Get original file info for content type detection
origFile, err := fs.Open(name)
if err == nil {
origFile.Close()
}
// Serve the compressed file
http.ServeContent(w, r, name, stat.ModTime(), cf)
return
}
}
}
// Fall back to serving the uncompressed file
file, err := fs.Open(name)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if stat.IsDir() {
http.Error(w, "is a directory", http.StatusForbidden)
return
}
http.ServeContent(w, r, name, stat.ModTime(), file)
}
+283
View File
@@ -0,0 +1,283 @@
package proxy
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"testing/fstest"
"time"
)
func TestServeCompressedFile_Brotli(t *testing.T) {
// Create test content
content := []byte("This is test content that should be compressed with brotli")
brContent := []byte("fake-brotli-compressed-data")
// Create a test filesystem
mapFS := fstest.MapFS{
"test.js": {Data: content, ModTime: time.Now()},
"test.js.br": {Data: brContent, ModTime: time.Now()},
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
req.Header.Set("Accept-Encoding", "br, gzip")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "test.js")
resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check that brotli is used (preferred over gzip)
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
}
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
}
if !bytes.Equal(body, brContent) {
t.Errorf("Expected brotli content, got %s", string(body))
}
}
func TestServeCompressedFile_Gzip(t *testing.T) {
// Create test content
content := []byte("This is test content that should be compressed with gzip")
gzContent := []byte("fake-gzip-compressed-data")
// Create a test filesystem without brotli
mapFS := fstest.MapFS{
"test.js": {Data: content, ModTime: time.Now()},
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
req.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "test.js")
resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
}
if !bytes.Equal(body, gzContent) {
t.Errorf("Expected gzip content, got %s", string(body))
}
}
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
// Create test content
content := []byte("This is uncompressed test content")
// Create a test filesystem without compressed versions
mapFS := fstest.MapFS{
"test.js": {Data: content, ModTime: time.Now()},
}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
req.Header.Set("Accept-Encoding", "br, gzip")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "test.js")
resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Should not have Content-Encoding header since we're serving uncompressed
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
}
if !bytes.Equal(body, content) {
t.Errorf("Expected original content, got %s", string(body))
}
}
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
// Create test content
content := []byte("This is test content")
// Create a test filesystem with compressed versions
mapFS := fstest.MapFS{
"test.js": {Data: content, ModTime: time.Now()},
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
// No Accept-Encoding header
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "test.js")
resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Should serve uncompressed content
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
}
if !bytes.Equal(body, content) {
t.Errorf("Expected original content, got %s", string(body))
}
}
func TestServeCompressedFile_NotFound(t *testing.T) {
mapFS := fstest.MapFS{}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "nonexistent.js")
resp := w.Result()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected status 404, got %d", resp.StatusCode)
}
}
func TestSelectEncoding(t *testing.T) {
tests := []struct {
acceptEncoding string
wantEncoding string
wantExt string
}{
{"br, gzip", "br", ".br"},
{"gzip, deflate", "gzip", ".gz"},
{"gzip", "gzip", ".gz"},
{"br", "br", ".br"},
{"", "", ""},
{"deflate", "", ""},
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
{"browser", "", ""},
{"compress, deflate", "", ""},
}
for _, tt := range tests {
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
}
}
}
// Test with actual pre-compressed files from ui_dist
func TestServeCompressedFile_RealFiles(t *testing.T) {
// Check if ui_dist exists
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
t.Skip("ui_dist not found, skipping real file test")
}
// Find a .js or .css file that has compressed versions
entries, err := os.ReadDir("./ui_dist/assets")
if err != nil {
t.Skipf("Could not read ui_dist/assets: %v", err)
}
var testFile string
for _, entry := range entries {
name := entry.Name()
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
// Check if compressed versions exist
base := strings.TrimSuffix(name, ".js")
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
testFile = "assets/" + name
break
}
}
}
if testFile == "" {
t.Skip("No suitable test file found with compressed versions")
}
fs := http.FS(os.DirFS("./ui_dist"))
// Test brotli
t.Run("brotli", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
req.Header.Set("Accept-Encoding", "br")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, testFile)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
}
})
// Test gzip
t.Run("gzip", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
req.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, testFile)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
}
// Verify it's valid gzip
reader, err := gzip.NewReader(resp.Body)
if err != nil {
t.Errorf("Expected valid gzip content: %v", err)
return
}
defer reader.Close()
// Just read to verify it's valid
_, err = io.Copy(io.Discard, reader)
if err != nil {
t.Errorf("Failed to decompress gzip: %v", err)
}
})
}
+2
View File
@@ -0,0 +1,2 @@
node_modules
.vite
+1
View File
@@ -0,0 +1 @@
legacy-peer-deps=true

Some files were not shown because too many files have changed in this diff Show More