Compare commits

..

88 Commits

Author SHA1 Message Date
Benson Wong a9d840ffd7 proxy,proxy/config: restore timeouts to pre PR 619 (#648)
Reset the default ResponseHeader timeout to 0 (no timeout) which was set
to 60 seconds in PR #619.

Fixes #647
2026-04-11 20:42:13 -07:00
Benson Wong 7b2b82777f docker/unified: derive rootless image from root container (#644)
Build the root image once, then derive the rootless variant from it
using a small inline Dockerfile that adds the non-root user and chowns
the writable directories. This halves the number of CI jobs (4 → 2) and
eliminates the redundant full CUDA compilation for the rootless variant.

- remove RUN_UID build arg from build-image.sh
- derive rootless image inline after root build completes
- collapse variant matrix out of unified-docker.yml
- push both root and rootless tags in a single CI job

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 22:59:54 -07:00
Benson Wong d87f0ce2c5 docker/unified: publish rootless image variant (#630) 2026-04-07 03:05:53 -07:00
Leoy 06bc6a614c proxy: preserve wall-clock duration in metrics (#629)
Keep request duration from being underreported when upstream timings
only cover part of the full request lifecycle.

- compare wall-clock and upstream timing durations
- keep token and throughput values from timings
- add regression coverage for underreported timings

fixes #602
2026-04-07 01:52:41 -07:00
Ron M a37b4866d8 proxy: add configurable HTTP timeouts for models and peers (#619)
Add configurable HTTP timeout settings to both models and peers to support installations that requires longer timeouts than the current hardcoded defaults.

Closes #618
2026-04-06 19:30:27 +08:00
Benson Wong 981910d734 ci: validate config.example.yaml against config-schema.json (#627)
Extend the existing config-schema workflow to also validate
config.example.yaml against config-schema.json using check-jsonschema.

- add config.example.yaml to PR and push path triggers
- install check-jsonschema via pip
- run validation of config.example.yaml against schema

https://claude.ai/code/session_01Y1oqwE6mwNs9UTJgZRgXtG

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-04-05 15:17:57 +08:00
Benson Wong a185efe37e docker: make CMAKE_CUDA_ARCHITECTURES configurable via build arg (#625)
Expose CMAKE_CUDA_ARCHITECTURES as a Docker build ARG so users can
customize CUDA architectures via --build-arg without editing the
Dockerfile.

- convert hardcoded ENV to ARG with default, feeding into ENV
- replace silent fallback defaults (:-) in scripts with :? guards
  to fail fast if the env var is missing
- add usage example to Dockerfile header

Follow up to: #624

https://claude.ai/code/session_01EWiUe7jNABX7Uz95dUGJqK

Co-authored-by: Claude <noreply@anthropic.com>
2026-04-04 08:49:59 +08:00
Benson Wong 1dd1aadf93 docker/unified: add ik_llama.cpp to CUDA container (#620) 2026-04-03 15:16:30 +08:00
Benson Wong 955900972a add /sdapi to list of supported endpoints 2026-04-01 12:01:38 +08:00
Benson Wong c2c8cfaf81 docker/unified: build llama.cpp with static libraries (#616) 2026-04-01 03:38:07 +08:00
Benson Wong 1e440770ea ci: fix matrix exclude for scheduled docker workflow (#610) 2026-03-29 20:04:28 +09:00
Benson Wong c794273c83 docker/unified,.github: fix unified build (#606) 2026-03-27 10:31:12 +09:00
dependabot[bot] 6574a52cbb build(deps): bump picomatch from 4.0.3 to 4.0.4 in /ui-svelte (#605) 2026-03-26 22:28:24 +09:00
Benson Wong 8fabc75634 docker/unified: vulkan build fixes (#600)
multiple fixes to vulkan build: 

- use ubuntu 26.04 to be compatible with AMD 395+ (Strix halo) hardware
- add home directory in container 
- fix stable-diffusion install to actually enable vulkan

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-25 23:26:13 +09:00
Benson Wong e5e7391b6d .github,docker/unified: include vulkan build (#599)
Update docker/unified scripts to support building both cuda and vulkan unified images.
2026-03-25 06:58:28 +09:00
Benson Wong 2c282dccad .github,docker/unified: improve caching and fix bugs (#598)
- set up a GHA scheduled job to build the container nightly 
- enabling pushing a llama-swap:unified and a llama-swap:unified-Y-M-D
image to ghcr.io
- tidy up Dockerfile to use a non-root user and llama-swap as an entry
point
2026-03-23 22:24:40 +09:00
Benson Wong 916d13f5bd .github/workflows,docker/unified: add cuda based unified container (#597)
Add Docker build scripts for a unified cuda docker container with llama-server, stable-diffusion.cpp, whisper.cpp.
2026-03-22 21:11:54 +09:00
Benson Wong a3725e7d09 Update go.mod to 1.26.1 (#593) 2026-03-20 16:09:58 +09:00
Benson Wong 15bd55d3a9 proxy, ui-svelte: add /sdapi/v1 endpoint support (#587)
Add proxy routes for stable-diffusion.cpp's /sdapi/v1/txt2img,
/sdapi/v1/img2img, and /sdapi/v1/loras endpoints. POST endpoints
use proxyInferenceHandler (model in JSON body), GET /loras uses
proxyGETModelHandler (model in query param).

Update the image playground with a dual-mode UI supporting both
OpenAI and SDAPI backends. In SDAPI mode, loras are fetched first
to prime the server-side cache, and all txt2img parameters are
exposed (negative prompt, steps, cfg_scale, seed, batch_size,
clip_skip, sampler, scheduler, lora selection with multipliers).

- Add 3 sdapi route registrations in proxymanager.go
- Add sdApi.ts client with generateSdImage and fetchSdLoras
- Add SDAPI types (SdApiTxt2ImgRequest, SdApiResponse, etc.)
- Add /sdapi to vite dev proxy config
- Add backend tests for sdapi routing
- Support batch image display in gallery grid

https://claude.ai/code/session_0186MGX6NXdHVBTv2KH45fqn

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-03-19 22:08:31 +09:00
Benson Wong c3c258a55d proxy: fix metrics capture for v1/responses (#586)
properly parse anthropic compatible usage data from streaming responses.

closes: #577
2026-03-13 16:50:12 -07:00
Benson Wong 29a38fde0d ui-svelte: upgrade to vite 8 (#585)
Upgrade vite and related dependencies to take advantage of Vite 8's
improved build times via Rolldown and Oxc.

- vite: ^6.3.5 → ^8.0.0
- @sveltejs/vite-plugin-svelte: ^5.0.3 → ^7.0.0
- svelte: ^5.19.0 → ^5.46.4
- vite-plugin-compression2: ^2.4.0 → ^2.5.1
- vitest: ^4.0.18 → ^4.1.0

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-03-13 08:45:59 -07:00
tesuri d569681daa Change model sorting to natural order (#582)
Use natural sorting for model names.

Previously the model list was sorted lexicographically, which resulted
in unintuitive ordering when numbers were included in the name.

Example:

Before
qwen3.5:2B
qwen3.5:35B-3AB
qwen3.5:9B

After
qwen3.5:2B
qwen3.5:9B
qwen3.5:35B-3AB

This change sorts models using natural order so numeric parts are
compared numerically.
2026-03-12 07:49:34 -07:00
Benson Wong 24efdb76b1 config: add macro support for name and description fields (#578)
Extend macro substitution to the name and description fields of
ModelConfig, matching the behavior already present for cmd, proxy,
checkEndpoint, and filters.

- substitute global/model macros (including MODEL_ID) in name and
description
- substitute PORT macro in name and description when allocated
- validate no unknown macros remain in name and description after
substitution
- add tests for macro substitution, MODEL_ID, and unknown macro error
2026-03-10 08:27:05 -07:00
Benson Wong cc77139ff8 proxy,proxy/config: add global TTL feature (#554)
Add a new configuration parameter globalTTL that all models will
inherit. The default value is 0 which matches the currently
functionality to never automatically unload a model.

The model.ttl's default has changed to -1, which means use the global
TTL value. Any model.ttl >=0 is now value with 0 meaning never unload.
This allows a model to override a globalTTL > 0 and be configured to
never unload.

Fixes #459
Closes #512
2026-03-01 21:02:12 -08:00
Benson Wong 390a35bf93 ui-svelte: add copy button to markdown code blocks (#537)
Add a copy-to-clipboard button that appears on hover for each code block
rendered in the chat interface assistant messages.

- Svelte action `codeBlockCopy` injects a button into every `<pre>`
element
- MutationObserver reattaches buttons as streaming content arrives
- Button shows a check icon for 2 seconds after a successful copy
- Uses clipboard API with execCommand fallback for non-secure contexts
- CSS hides button by default and reveals it on pre:hover

https://claude.ai/code/session_01PTA5ao5YQuFAS6a9juLeZW

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-03-01 09:48:56 -08:00
pdscomp 181f71ca11 .github,docker: add cuda13 architecture support (#551)
Add `cuda13` as a supported build architecture, targeting the
`ghcr.io/ggml-org/llama.cpp:server-cuda13` upstream base image.

The `server-cuda13` image ships with CUDA 13 libraries, providing
improved performance on recent NVIDIA hardware compared to the existing
`server-cuda` (CUDA 12) image. Users with newer GPUs (e.g., RTX
50-series) benefit from reduced model load latency and higher token
throughput.

- Add `cuda13` to the allowed architectures list in
`docker/build-container.sh`
- Add `cuda13` to the CI matrix in `.github/workflows/containers.yml` so
the container is built and pushed automatically
2026-03-01 09:37:08 -08:00
Benson Wong 49546e2cf2 ui: fix text size svg 2026-02-27 23:47:52 -08:00
Benson Wong 2c078964f4 Update README with additional images
Added new images for model loading and real-time log streaming sections.
2026-02-27 23:45:40 -08:00
Benson Wong 175bb36fb1 Revise README description for clarity and detail
Updated description to clarify compatibility and usage.
2026-02-27 23:42:40 -08:00
Benson Wong aedb640471 Enhance web UI section in README
Updated README to enhance the description of the web interface and added details about features like token metrics, request inspection, model management, and real-time log streaming.
2026-02-27 23:40:31 -08:00
Benson Wong 2f377f6dc6 ui: add OGG audio format support to transcription playground (#544) 2026-02-26 19:48:19 -08:00
Benson Wong 64e4c79fc3 ui: add Rerank tab to playground (#536)
Add a new Rerank tab to the playground that lets users test /v1/rerank
endpoints. Supports a visual table editor and a JSON editor mode that
stay in sync when toggling between them.

- add rerankApi.ts with typed wrapper for /v1/rerank
- add RerankInterface.svelte with query input, sortable document table,
color-coded scores, auto-add row, cancel/clear, and token usage
- add rerankLoading store to playgroundActivity derived store
- register Rerank tab in Playground.svelte

Updates #481
2026-02-21 21:59:14 -08:00
Benson Wong 19fb5f35e9 proxy: implement setParamsByID filter (#535)
Add setParamsByID filter that applies different request parameters based
on the requested model ID, enabling per-alias behaviour for a single
loaded model.

- add SetParamsByID field to Filters struct and SanitizedSetParamsByID
method
- substitute ${MODEL_ID} and other macros in setParamsByID keys and
values
- validate no unknown macros remain in keys or values after substitution
- apply setParamsByID in proxyInferenceHandler after setParams (can
override it)
- update config-schema.json with setParamsByID definition
- update UI to show aliases and make them selectable in the Playground

closes #534
2026-02-19 22:21:10 -08:00
Benson Wong b45102bde8 ui: smart auto-scroll in LogPanel (#530)
Pause auto-scroll when the user scrolls up to review logs, and resume
when they scroll back to the bottom.

- add `userScrolledUp` state variable
- add `handleScroll` to detect scroll position with 40px threshold
- guard the auto-scroll effect with `!userScrolledUp`

closes #529
2026-02-18 19:47:37 -08:00
Brian Mendonca 1688bdd1e9 proxy, ui: add pending requests count to the main dashboard (#516)
add a real time counter of pending (inflight) requests to the UI.
2026-02-16 09:41:15 -08:00
Benson Wong d33d51fa75 .coderabbit.yaml,AGENTS.md: small tweaks 2026-02-15 21:31:30 -08:00
Benson Wong e3bf065574 ui: persist playground state across route navigation (#525)
- Keep Playground component mounted when navigating away, preserving
streaming/generating state
- Add animated gradient effect on Playground nav link when activity is
in progress
2026-02-15 21:30:52 -08:00
Benson Wong 3e52144058 ui-svelte: incremental rendering of chat messages in the Playground (#520)
add incremental rendering to Playground > Chat
2026-02-15 11:00:44 -08:00
Benson Wong d5e52d7d00 build: disable provenance attestations in container builds (#523)
## Summary
- Add `--provenance=false` to docker build commands in
`build-container.sh`
- BuildKit attestation manifests are stored as untagged images in GHCR,
and the `delete-untagged-containers` cleanup job deletes them, breaking
the manifest list and causing `manifest unknown` errors on pull
- ref: https://github.com/actions/delete-package-versions/issues/162
2026-02-14 10:23:08 -08:00
Benson Wong 17e5263a76 .github/workflows: fix expired token in publishing images (#522)
Fixes: #517
2026-02-14 10:06:05 -08:00
Benson Wong 8d6d949ec3 proxy: support timings for /infill from llama-server (#510)
fixes: #463
2026-02-07 17:16:27 -08:00
Benson Wong b5fde8eb6d proxy,ui-svelte: add request/response capturing (#508)
Add saving request and response headers and bodies that go through
llama-swap in memory.

- captureBuffer added to configuration. Captures are enabled by default.
- 5MB of memory is allocated for req/response captures in a ring buffer.
Setting captureBuffer to 0 will disable captures.
- UI elements to view captured data added to Activity page. Includes
some
QOL features like json formatting and recombining SSE chat streams
- capture saving is done at the byte level and has minimal impact on
llama-swap performance

Fixes #464 
Ref #503
2026-02-07 15:40:01 -08:00
Nuno 7eef5defb8 docs: add stable-diffusion.cpp references (#506)
Signed-off-by: rare-magma <rare-magma@posteo.eu>
2026-02-04 20:20:39 -08:00
Benson Wong bc01e6f539 build: add stable-diffusion server to musa and vulkan container images (#504)
Add sd-server from stable-diffusion.cpp docker image for 
vulkan and musa containers.

closes #450
2026-02-01 16:17:26 -08:00
Benson Wong 0462e3dc3f Reorganize UI controls and improve form interactions (#500)
Reorganizes control placement in the playground interfaces and
improves form interactions for better UX, particularly on mobile
devices.

## Key Changes

- **AudioInterface & ImageInterface**: Moved "Clear" buttons from the
top control bar into the action button group below the form inputs for
better visual hierarchy and logical grouping
- **ImageInterface**: 
- Added prompt clearing to the `clearImage()` function so the input
field is reset when clearing generated images
- Updated Clear button disabled state to also check if prompt is empty,
allowing users to clear an empty prompt
- Added responsive flex styling (`flex-1 md:flex-none`) to the Clear
button for better mobile layout
- **ExpandableTextarea**: 
- Imported `untrack` from Svelte to properly handle reactive
dependencies
- Wrapped `expandedValue.length` in `untrack()` to prevent unnecessary
reactivity when setting cursor position
- Improved button visibility on mobile by changing opacity from
`opacity-0` to `opacity-60` with `md:opacity-0` breakpoint, making the
expand button more discoverable on touch devices

## Implementation Details
The `untrack()` usage in ExpandableTextarea ensures that reading the
text length doesn't create a reactive dependency, preventing potential
infinite loops while still allowing the effect to run when `isExpanded`
changes.
2026-02-01 15:18:22 -08:00
Benson Wong 7b20fc011b Add path filters to CI workflows and create UI test workflow (#501)
* .github/workflows: add UI tests and path-filter Go CI

Add ui-tests.yml workflow to run svelte type checking and vitest
on push/PR to main when ui-svelte/ files change.

- Add path filters to go-ci.yml and go-ci-windows.yml to skip
  Go tests when only non-backend files change
- Filter on **/*.go, go.mod, go.sum, and Makefile

https://claude.ai/code/session_01E6acq54D8JjuE7pczxPGT7

* ui-svelte: remove unused declarations in SpeechInterface

Remove unused `generatedText` state and `clearAudio` function
that caused svelte-check errors.

https://claude.ai/code/session_01E6acq54D8JjuE7pczxPGT7

* .github/workflows: update Node.js to v24

Node 23 is end-of-life; bump to 24 in ui-tests.yml and release.yml.

https://claude.ai/code/session_01E6acq54D8JjuE7pczxPGT7

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-01 15:11:49 -08:00
Benson Wong 20738f3623 proxy,ui-svelte: replace old UI with svelte+playground
Replace the legacy React UI with the new Svelte-based one. Introduce a Playground in the UI to quickly test out text, image, text to speech and speech to text models behind llama-swap. 

Key Changes

New Svelte UI (ui-svelte/)

  - Multi-tab Playground with Chat, Image Generation, Audio Transcription, and Speech interfaces
  - Chat: message editing/regeneration, markdown rendering with LaTeX math support, image attachments, code syntax highlighting
  - Image: size selector, download/fullscreen viewing
  - Audio: transcription with peer support
  - Speech: voice caching with manual refresh, download button
  - Responsive mobile layout with collapsible navigation
  - XSS fixes and accessibility improvements

Proxy Improvements

  - Add gzip/brotli compression for UI static assets (proxy/ui_compress.go)
  - Add GET /v1/audio/voices?model={model} endpoint for voice listing
  - Add peer support for /v1/audio/transcriptions
2026-01-31 22:49:13 -08:00
Benson Wong cdea7d16bd proxy/config: skip env macros in YAML comment lines (#496)
Fix a bug where ${env.macro_not_exist} in comments would trigger a non-substituted macro error. 

fixes #495
2026-01-30 20:10:29 -08:00
Benson Wong 5de387dbf9 ui: fix node-tar vulnerability 2026-01-28 21:40:18 -08:00
Benson Wong 6f8e7ccb57 .github/workflows: switch release.yml to build ui-svelte 2026-01-28 21:39:10 -08:00
Benson Wong 4384315b44 ui-svelte: add Svelte port of React UI (#487)
Trying out svelte for the UI. The port was done by Claude Code on the iOS app w/ Opus 4.5. 

---

* ui: add Svelte port of React UI

Port the React-based UI to Svelte 5 with the following changes:

- Create new ui-svelte directory with complete Svelte 5 implementation
- Use Svelte stores instead of React contexts for state management
- Implement custom ResizablePanels component to replace react-resizable-panels
- Port all pages: LogViewer, Models, Activity
- Port all components: Header, ConnectionStatus, LogPanel, ModelsPanel, etc.
- Use svelte-spa-router for client-side routing
- Same build output directory (proxy/ui_dist) and base path (/ui/)
- Tailwind CSS 4 with same theme configuration

https://claude.ai/code/session_01F3xXLYsd62gePVSFv7aboP

* ui-svelte: simplify state management

- Remove redundant state syncing pattern in LogPanel and ModelsPanel
- Use store values directly with $ syntax instead of manual subscriptions
- Consolidate duplicate title sync logic in App.svelte
- Use existing syncTitleToDocument() from theme.ts

https://claude.ai/code/session_01F3xXLYsd62gePVSFv7aboP

* ui-svelte: use idiomatic Svelte 5 patterns

- Use $effect for document side effects (theme, title) instead of
  store subscriptions
- Use class: directive for active nav links in Header
- Remove SSR guards (unnecessary for client-only SPA)
- Remove leaked subscription in syncThemeToDocument
- Simplify theme.ts by removing sync functions

https://claude.ai/code/session_01F3xXLYsd62gePVSFv7aboP

* ui-svelte: fix build warnings and improve accessibility

Fix Svelte build warnings and add proper accessibility support
to interactive components.

- add aria-labels to buttons for screen readers
- implement keyboard navigation for resizable separator
- suppress intentional state initialization warnings
- update Makefile to use ui-svelte build directory
- add peer:true to package-lock.json dependencies

* ui-svelte: reorganize navigation and add log view toggle

Make Models the default landing page and add view mode toggle
to the Logs page with persistent state.

- set Models as default route at /
- move Logs to /logs route
- reorder navigation: Models, Activity, Logs
- add view toggle with three modes: Panels, Proxy only, Upstream only
- fix horizontal overflow with width constraints
2026-01-28 21:37:29 -08:00
Benson Wong 6439ab1515 ui: add peer:true in package-lock.json 2026-01-22 08:43:36 -08:00
dependabot[bot] f94226122c build(deps-dev): bump tar from 7.5.3 to 7.5.6 in /ui (#477)
Bumps [tar](https://github.com/isaacs/node-tar) from 7.5.3 to 7.5.6.
- [Release notes](https://github.com/isaacs/node-tar/releases)
- [Changelog](https://github.com/isaacs/node-tar/blob/main/CHANGELOG.md)
- [Commits](https://github.com/isaacs/node-tar/compare/v7.5.3...v7.5.6)

---
updated-dependencies:
- dependency-name: tar
  dependency-version: 7.5.6
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-21 22:55:02 -08:00
Ryan Voots 7493618fdc Add count_tokens api proxying (#476) 2026-01-20 09:34:42 -08:00
Benson Wong 205efd40a1 proxy: extend /running endpoint with additional process data (#474)
Extend the /running endpoint to return more details about running
processes beyond just model and state.

- add cmd field to show the command being executed
- add proxy field to show the proxy URL
- add ttl (UnloadAfter) for automatic unloading configuration
- add name and description for model metadata
- update tests to verify new fields are returned correctly

fixes #471
2026-01-19 17:37:00 -08:00
Benson Wong 14207f8492 ui: npm security update 2026-01-18 21:56:32 -08:00
Benson Wong 4e850c2834 config: refactor macro substitution in configuration (#470)
This commit simplifies substitution of environment variables into the configuration. There was a lot of repetitive code substituting ${env.VAR_NAME} into different fields after the configuration was parsed into a config.Config. This refactor uses a string substitution of env vars into the YAML config before it is fully parsed. This eliminates a lot of logic while maintaining backwards compatibility.
2026-01-18 21:52:34 -08:00
Benson Wong 75fced579e config: support macros in peer apiKey and filters (#469)
* config: support environment variable macros in peer apiKeys

Add ${env.VAR_NAME} substitution for peer apiKey fields, consistent
with existing env macro support for model fields and global apiKeys.

- Add env macro substitution for peers.{name}.apiKey in LoadConfigFromReader
- Add tests for peer apiKey env substitution
- Update config.example.yaml to show env macro usage

* config: support macros in peer apiKey and filters

Extend macro substitution to peer configuration fields:
- peers.{name}.apiKey supports both global macros and env macros
- peers.{name}.filters.stripParams supports both macro types
- peers.{name}.filters.setParams supports both macro types

Also renamed validateMetadataForUnknownMacros to validateNestedForUnknownMacros
for reuse across model metadata and peer filters validation.
2026-01-16 23:10:50 -08:00
Benson Wong b73f367f22 config-schema.json,config.example.yaml: Update examples and schema 2026-01-16 22:43:25 -08:00
Benson Wong 8f2137c72b config: support environment variable macros in apiKeys (#467)
Add substituteEnvMacros support for apiKeys configuration field,
allowing API keys to be loaded from environment variables using
the ${env.VAR_NAME} syntax.

- Apply env macro substitution before validation
- Add tests for env macro substitution in apiKeys
2026-01-16 22:41:14 -08:00
Benson Wong 124007cc98 config: add environment variable macros (#466)
* config: add environment variable macros

Add support for ${env.VAR_NAME} syntax to pull values from system
environment variables during config loading.

- env macros processed before regular macros (allows macros to reference env vars)
- works in cmd, cmdStop, proxy, checkEndpoint, filters.stripParams, metadata
- returns error if env var is not set
- add comprehensive tests

fixes #462

* docs: add env macro example to config.example.yaml
2026-01-16 22:25:20 -08:00
Benson Wong eb5bfff0b0 proxy: unify filtering for local models and peers
This unifies the filtering capabilities for models and peers

- stripParams: removes params in the request
- setParams: sets params in the request

fixes #453
2026-01-15 18:59:43 -08:00
Benson Wong 3edb180c08 ci: free up disk space before ROCm container build (#460) 2026-01-14 22:03:42 -08:00
Benson Wong 66d555e625 Improve container build reliability (#457)
* docker: add .env usage in build-container.sh
* .github,docker: add rocm, improve logging
* .github,CLAUDE.md: fix workflow and update guidelines

Update containers workflow to only push images when triggered
manually or on schedule, not on workflow file changes.

- add push trigger for workflow file changes in containers.yml
- update push condition to skip on regular push events
- update CLAUDE.md commit message guidelines

* docker: remove comma in build-container.sh

* .github,docker: improve container build workflow

Add pagination support for fetching llama.cpp tags and improve debugging.

- add build-container.sh to workflow trigger paths
- implement fetch_llama_tag() with pagination support
- replace .env with local testing instructions
- add DEBUG_ABORT_BUILD flag for testing
2026-01-10 22:14:33 -08:00
Benson Wong 4f863fd9fc CLAUDE.md: tweak instructions 2026-01-09 21:42:06 -08:00
Benson Wong 267c030457 ui: update react-router-dom to 7.12.0 (#456)
Update react-router-dom from 7.6.2 to 7.12.0 to address security vulnerability.

- Updated dependency in package.json
- Regenerated package-lock.json
- Verified build passes successfully
- Confirmed 0 vulnerabilities with npm audit

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-08 16:13:09 -08:00
Benson Wong c19309fe7e CLAUDE.md: small instruction tweaks 2026-01-07 21:34:23 -08:00
Benson Wong 4413881b2d proxy: actually add /v1/responses endpoint (#449)
ref: #448
2026-01-01 13:35:45 -08:00
Benson Wong 8df5e8563b proxy: add /v1/responses and /v1/audio/voices endpoints (#448)
Updates #433
Fixes #442 #226
2026-01-01 12:52:12 -08:00
Benson Wong 7931212d3e proxy: add v1/images/edits API endpoint (#447)
Updates #433
2026-01-01 12:43:06 -08:00
Benson Wong 3dc36032fb proxy: skip very slow tests in -short test mode (#446)
* proxy: skip very slow tests in -short test mode
* CLAUDE.md: update testing instructions
2025-12-31 14:08:56 -08:00
Benson Wong addb98646f proxy: add support for basic authorization (#445)
Fixes #444 where the UI with api keys did not work. The choice to use
http basic authorization is for simple, automatic browser support. No
changes to the UI were necessary. Just use an API key as the password,
no user name is required.
2025-12-31 13:42:35 -08:00
Benson Wong 37d74efc2d proxy: add /v1/images/generations (#443)
Add support for the /v1/images/generations endpoint

Updates #433
Closes #191
2025-12-30 21:04:58 -08:00
Benson Wong 22e098ac8b Add Peer Model Support (#438)
This PR allows a single llama-swap to be the central proxy for models served by other inference servers. The peer servers can be another llama-swap or any API that supports the /v1/* inference endpoint.

Updates: #433, #299
Closes: #296
2025-12-27 20:18:06 -08:00
Benson Wong 9864f9f517 .coderabbit.yaml: disable annoying features 2025-12-23 23:53:06 -08:00
Benson Wong 53b32f3601 proxy: add API key support (#436)
Add configuration support for api keys that are enforced by llama-swap. Keys are stripped before sending them to upstream servers. 

Updates: #433, #50 and #251
2025-12-23 23:39:33 -08:00
Benson Wong 565c44766d config,proxy: add new configuration logToStdout (#432)
The new logToStdout option controls what is logged to stdout. The
default has been changed to just the proxy logs, which contain swap and
http request logs.

There are four supported settings: none, proxy, upstream, both. The
"both" setting is the legacy setting where everything was spewed to
stdout.
2025-12-21 22:23:31 -08:00
Benson Wong e6a9e210ba proxy: fix path bug in /logs/stream/{model_id} (#431)
A {model_id} containing a forward slash trips up gin's path param
parsing. This updates /logs/stream to work like /upstream where the
model_id is built up in parts and searched for in the configuration.

Updates #421
2025-12-21 21:47:14 -08:00
Benson Wong d3f329f924 proxy: Improve logging performance and allow separate log streaming (#421)
Replace container/ring.Ring with a custom circularBuffer that uses a
single contiguous []byte slice. This fixes the original implementation
which created 10,240 ring elements instead of 10KB of storage.

GetHistory is now 139x faster (145μs → 1μs) and uses 117x less memory
(1.2MB → 10KB). Allocations reduced from 2 to 1 per write operation.

Create a LogMonitor per proxy.Process, replacing the usage
of a shared one. The buffer in LogMonitor is lazy allocated on the first
call to Write and freed when the Process is stopped. This reduces
unnecessary memory usage when a model is not active.

The /logs/stream/{model_id} endpoint was added to stream logs from a
specific process.
2025-12-18 21:49:25 -08:00
Benson Wong 98879b38c1 docker: add /app to $PATH (#424)
Make it so llama-server can be called directly instead of with the full
path at /app/llama-server.

Fixes #423
Ref: #233
2025-12-06 22:58:29 -08:00
Benson Wong 7b3b0f5eae move header images around [skip ci] 2025-12-02 19:40:42 -08:00
Benson Wong 021ccceef1 README: update hero image 2025-12-02 19:37:03 -08:00
Benson Wong f03871c50a Update README.md
- add supported anthropic API 
- add example for docker hot reload support
2025-12-02 19:03:01 -08:00
Ryan Steed dc00d17abe docs: add documentation for non-root container images and security considerations (#416)
* docs: add documentation for non-root container images and security considerations
* docs: move container security section to dedicated file and update README links
2025-12-02 08:52:26 -08:00
Benson Wong dea98733c3 proxy: extract metrics for v1/messages (#419) 2025-11-29 23:51:20 -08:00
Benson Wong bccce5fa19 go.mod,ui/package-lock.json: dependency and security updates (#418) 2025-11-29 22:27:22 -08:00
Benson Wong c968da1b73 proxy: add support for anthropic v1/messages api (#417)
* proxy: add support for anthropic v1/messages api
* proxy: restrict loading message to /v1/chat/completions
2025-11-29 22:09:07 -08:00
Ryan Steed a883d68d4f feat: Add support for custom llama.cpp base image and forked llama-swap repositories (#396)
* feat: Add support for custom llama.cpp base image and forked llama-swap repositories

- Introduce BASE_LLAMACPP_IMAGE env var to customize llama.cpp base image
- Introduce LS_REPO env var to customize llama-swap source
- Use GITHUB_REPOSITORY env var to automatically detect forked repos
- Update container tagging to use dynamic repo paths
- Pass build args for BASE_IMAGE and LS_REPO to Containerfile
- Enable flexible release downloads from forked repositories

* chore: quote entire curl options, appease coderabbitai
2025-11-29 20:59:15 -08:00
147 changed files with 17323 additions and 6225 deletions
+8 -1
View File
@@ -4,12 +4,19 @@ early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
high_level_summary: false
poem: false
review_status: true
collapse_walkthrough: false
sequence_diagrams: false
finishing_touches:
docstrings:
enabled: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true
issue_enrichment:
planning:
enabled: false
+15
View File
@@ -4,11 +4,15 @@ on:
pull_request:
paths:
- "config-schema.json"
- "config.example.yaml"
- ".github/workflows/config-schema.yml"
push:
branches:
- main
paths:
- "config-schema.json"
- "config.example.yaml"
- ".github/workflows/config-schema.yml"
workflow_dispatch:
@@ -39,3 +43,14 @@ jobs:
fi
echo "✓ config-schema.json is valid"
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Install check-jsonschema
run: pip install check-jsonschema
- name: Validate config.example.yaml against schema
run: check-jsonschema --schemafile config-schema.json config.example.yaml
+29 -2
View File
@@ -10,17 +10,44 @@ on:
# Allows manual triggering of the workflow
workflow_dispatch:
# Run on workflow file changes (without pushing)
push:
paths:
- '.github/workflows/containers.yml'
- 'docker/build-container.sh'
- 'docker/*.Containerfile'
# grant permissions on GITHUB_TOKEN to publish packages
# ref: https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#publishing-a-package-using-an-action
permissions:
contents: read
packages: write
id-token: write
jobs:
build-and-push:
runs-on: ubuntu-latest
strategy:
matrix:
platform: [intel, cuda, vulkan, cpu, musa]
platform: [intel, cuda, cuda13, vulkan, cpu, musa, rocm]
fail-fast: false
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Free up disk space
if: matrix.platform == 'rocm'
run: |
echo "Before cleanup:"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
echo "After cleanup:"
df -h
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
@@ -31,7 +58,7 @@ jobs:
- name: Run build-container
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: ./docker/build-container.sh ${{ matrix.platform }} true
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
# see: https://github.com/actions/delete-package-versions/issues/74
+18 -2
View File
@@ -3,9 +3,25 @@ name: Windows CI
on:
push:
branches: [ "main" ]
# only run when backend source changes
# cmd/ is excluded because it contains utilities without tests
paths:
- '**/*.go'
- '!cmd/**'
- 'go.mod'
- 'go.sum'
- 'Makefile'
- '.github/workflows/go-ci-windows.yml'
pull_request:
branches: [ "main" ]
paths:
- '**/*.go'
- '!cmd/**'
- 'go.mod'
- 'go.sum'
- 'Makefile'
- '.github/workflows/go-ci-windows.yml'
# Allows manual triggering of the workflow
workflow_dispatch:
@@ -28,7 +44,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
@@ -43,7 +59,7 @@ jobs:
uses: actions/cache/save@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
- name: Test all
shell: bash
+19 -3
View File
@@ -3,9 +3,25 @@ name: Linux CI
on:
push:
branches: [ "main" ]
# only run when backend source changes
# cmd/ is excluded because it contains utilities without tests
paths:
- '**/*.go'
- '!cmd/**'
- 'go.mod'
- 'go.sum'
- 'Makefile'
- '.github/workflows/go-ci.yml'
pull_request:
branches: [ "main" ]
paths:
- '**/*.go'
- '!cmd/**'
- 'go.mod'
- 'go.sum'
- 'Makefile'
- '.github/workflows/go-ci.yml'
# Allows manual triggering of the workflow
workflow_dispatch:
@@ -20,7 +36,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23'
go-version-file: go.mod
# Only run in this linux based runner
- name: Check Formatting
@@ -35,7 +51,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: ./build
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
# necessary for testing proxy/Process swapping
- name: Create simple-responder
@@ -51,4 +67,4 @@ jobs:
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
- name: Test all
run: make test-all
run: make test-all
+11 -16
View File
@@ -3,13 +3,13 @@ name: goreleaser
on:
push:
tags:
- '*'
- "*"
# Allows manual triggering of the workflow
workflow_dispatch:
inputs:
tag:
description: 'Tag version to release (e.g. v144)'
description: "Tag version to release (e.g. v144)"
required: true
permissions:
@@ -19,35 +19,30 @@ jobs:
goreleaser:
runs-on: ubuntu-latest
steps:
-
name: Checkout
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.tag || github.ref }}
-
name: Set up Go
- name: Set up Go
uses: actions/setup-go@v5
-
name: Set up Node.js
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '23'
-
name: Install dependencies and build UI
node-version: "24"
- name: Install dependencies and build UI
run: |
cd ui
cd ui-svelte
npm ci
npm run build
-
name: Run GoReleaser
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
# either 'goreleaser' (default) or 'goreleaser-pro'
distribution: goreleaser
# 'latest', 'nightly', or a semver
version: '~> v2'
version: "~> v2"
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -76,4 +71,4 @@ jobs:
"release": {
"tag_name": "${{ steps.tag.outputs.tag }}"
}
}
}
+42
View File
@@ -0,0 +1,42 @@
name: UI Tests
on:
push:
branches: [ "main" ]
paths:
- 'ui-svelte/**'
- '.github/workflows/ui-tests.yml'
pull_request:
branches: [ "main" ]
paths:
- 'ui-svelte/**'
- '.github/workflows/ui-tests.yml'
workflow_dispatch:
jobs:
run-tests:
runs-on: ubuntu-latest
defaults:
run:
working-directory: ui-svelte
steps:
- uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '24'
cache: 'npm'
cache-dependency-path: ui-svelte/package-lock.json
- name: Install dependencies
run: npm ci
- name: Type check
run: npm run check
- name: Run tests
run: npm test
+131
View File
@@ -0,0 +1,131 @@
name: Build Unified Docker Image
on:
schedule:
- cron: "37 5 * * *"
workflow_dispatch:
inputs:
llama_cpp_ref:
description: "llama.cpp commit hash, tag, or branch"
required: false
default: "master"
whisper_ref:
description: "whisper.cpp commit hash, tag, or branch"
required: false
default: "master"
sd_ref:
description: "stable-diffusion.cpp commit hash, tag, or branch"
required: false
default: "master"
ik_llama_ref:
description: "ik_llama.cpp commit hash, tag, or branch (CUDA only)"
required: false
default: "main"
llama_swap_version:
description: "llama-swap version (e.g. v198, latest, main)"
required: false
default: "main"
build_cuda:
description: "Build CUDA image"
type: boolean
required: false
default: true
build_vulkan:
description: "Build Vulkan image"
type: boolean
required: false
default: true
permissions:
contents: read
packages: write
jobs:
setup:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- id: set-matrix
run: |
backends=()
# schedule uses defaults (build both); workflow_dispatch respects inputs
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_cuda }}" == "true" ]]; then
backends+=("cuda")
fi
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_vulkan }}" == "true" ]]; then
backends+=("vulkan")
fi
matrix=$(printf '%s\n' "${backends[@]}" | jq -R . | jq -sc .)
echo "matrix=$matrix" >> $GITHUB_OUTPUT
build:
needs: setup
if: ${{ needs.setup.outputs.matrix != '[]' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Free up disk space
run: |
echo "Before cleanup:"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
echo "After cleanup:"
df -h
# On GitHub Actions runners, create a fresh builder.
# When running locally under act, skip this and reuse the existing
# llama-swap-builder (which has ccache warm) to avoid exhausting disk.
- name: Set up Docker Buildx
if: ${{ !env.ACT }}
uses: docker/setup-buildx-action@v3
- name: Log in to GitHub Container Registry
if: ${{ !env.ACT }}
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build unified Docker image (${{ matrix.backend }})
env:
LLAMA_REF: ${{ inputs.llama_cpp_ref || 'master' }}
WHISPER_REF: ${{ inputs.whisper_ref || 'master' }}
SD_REF: ${{ inputs.sd_ref || 'master' }}
IK_LLAMA_REF: ${{ inputs.ik_llama_ref || 'main' }}
LS_VERSION: ${{ inputs.llama_swap_version || 'main' }}
DOCKER_IMAGE_TAG: ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}
# When running under act, use the local builder that has warm ccache.
# On GitHub Actions, BUILDX_BUILDER is unset so docker uses the builder
# created by setup-buildx-action above.
BUILDX_BUILDER: ${{ env.ACT == 'true' && 'llama-swap-builder' || '' }}
run: |
chmod +x docker/unified/build-image.sh
docker/unified/build-image.sh --${{ matrix.backend }}
- name: Push to GitHub Container Registry
if: ${{ !env.ACT }}
run: |
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
DATE_TAG=$(date -u +%Y-%m-%d)
docker push "${BASE_TAG}"
docker tag "${BASE_TAG}" "${BASE_TAG}-${DATE_TAG}"
docker push "${BASE_TAG}-${DATE_TAG}"
ROOTLESS_TAG="${BASE_TAG}-rootless"
docker push "${ROOTLESS_TAG}"
docker tag "${ROOTLESS_TAG}" "${ROOTLESS_TAG}-${DATE_TAG}"
docker push "${ROOTLESS_TAG}-${DATE_TAG}"
+51
View File
@@ -0,0 +1,51 @@
## Project Description:
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
## Tech stack
- golang
- typescript, vite and svelt5 for UI (located in ui/)
## Workflow Tasks
- when summarizing changes only include details that require further action
- just say "Done." when there is no further action
- use the github CLI `gh` to create pull requests and work with github
- Rules for creating pull requests:
- keep them short and focused on changes.
- never include a test plan
- write the summary using the same style rules as commit message
## Testing
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests.
### Commit message example format:
```
proxy: add new feature
Add new feature that implements functionality X and Y.
- key change 1
- key change 2
- key change 3
fixes #123
```
## Code Reviews
- use three levels High, Medium, Low severity
- label each discovered issue with a label like H1, M2, L3 respectively
- High severity are must fix issues (security, race conditions, critical bugs)
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
- Low severity are nice to have changes and nits
- Include a suggestion with each discovered item
- Limit your code review to three items with the highest priority first
- Double check your discovered items and recommended remediations
+1 -43
View File
@@ -1,43 +1 @@
# Project: llama-swap
## Project Description:
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
## Tech stack
- golang
- typescript, vite and react for UI (ui/)
## Testing
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
## Workflow Tasks
### Plan Improvements
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
When the user asks to improve a plan follow these guidelines for expanding and improving it.
- Identify any inconsistencies.
- Expand plans out to be detailed specification of requirements and changes to be made.
- Plans should have at least these sections:
- Title - very short, describes changes
- Overview: A more detailed summary of goal and outcomes desired
- Design Requirements: Detailed descriptions of what needs to be done
- Testing Plan: Tests to be implemented
- Checklist: A detailed list of changes to be made
Look for "plan expansion" as explicit instructions to improve a plan.
### Implementation of plans
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
## General Rules
- when summarizing changes only include details that require further action (action items)
- when there are no action items, just say "Done."
@AGENTS.md
+3 -3
View File
@@ -36,11 +36,11 @@ test-all: proxy/ui_dist/placeholder.txt
go test -race -count=1 ./proxy/...
ui/node_modules:
cd ui && npm install
cd ui-svelte && npm install
# build react UI
ui: ui/node_modules
cd ui && npm run build
cd ui-svelte && npm run build
# Build OSX binary
mac: ui
@@ -51,7 +51,7 @@ mac: ui
linux: ui
@echo "Building Linux binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
#GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
# Build Windows binary
windows: ui
+56 -16
View File
@@ -1,11 +1,11 @@
![llama-swap header image](header2.png)
![llama-swap header image](docs/assets/hero3.webp)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
# llama-swap
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
@@ -13,18 +13,29 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
- ✅ On-demand model switching
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
- future proof, upgrade your inference servers at any time.
- ✅ OpenAI API supported endpoints:
- `v1/completions`
- `v1/chat/completions`
- `v1/responses`
- `v1/embeddings`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- `v1/audio/voices`
- `v1/images/generations`
- `v1/images/edits`
- ✅ Anthropic API supported endpoints:
- `v1/messages`
- `v1/messages/count_tokens`
- ✅ llama-server (llama.cpp) supported endpoints
- `v1/rerank`, `v1/reranking`, `/rerank`
- `/infill` - for code infilling
- `/completion` - for completion endpoint
- ✅ SDAPI via [stable-diffusion.cpp's server](https://github.com/leejet/stable-diffusion.cpp/tree/master/examples/server)
- `/sdapi/v1/txt2img`
- `/sdapi/v1/img2img`
- `/sdapi/v1/loras` - requires `model` in request body to fetch the correct loras
- ✅ llama-swap API
- `/ui` - web UI
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
@@ -32,6 +43,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/log` - remote log monitoring
- `/health` - just returns "OK"
- ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- Automatic unloading of models after timeout by setting a `ttl`
@@ -40,14 +52,27 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
### Web UI
llama-swap includes a real time web interface for monitoring logs and controlling models:
llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
View detailed token metrics:
<img width="1111" height="515" alt="image" src="https://github.com/user-attachments/assets/64bfb280-d7a3-4126-971a-a128fd40410c" />
Inspect request and responses:
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
Manually load and unload models:
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
The Activity Page shows recent requests:
Real time log streaming:
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
## Installation
@@ -61,7 +86,8 @@ llama-swap can be installed in multiple ways
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc).
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
```shell
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
@@ -71,6 +97,14 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda
# configuration hot reload supported with a
# directory volume mount
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
-v /path/to/config:/config \
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
```
<details>
@@ -89,6 +123,9 @@ docker pull ghcr.io/mostlygeek/llama-swap:musa
# tagged llama-swap, platform and llama-server version images
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
# non-root cuda
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
```
</details>
@@ -191,23 +228,26 @@ As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. H
## Monitoring Logs on the CLI
```shell
```sh
# sends up to the last 10KB of logs
curl http://host/logs'
$ curl http://host/logs
# streams combined logs
curl -Ns 'http://host/logs/stream'
curl -Ns http://host/logs/stream
# just llama-swap's logs
curl -Ns 'http://host/logs/stream/proxy'
# stream llama-swap's proxy status logs
curl -Ns http://host/logs/stream/proxy
# just upstream's logs
curl -Ns 'http://host/logs/stream/upstream'
# stream logs from upstream processes that llama-swap loads
curl -Ns http://host/logs/stream/upstream
# stream logs only from a specific model
curl -Ns http://host/logs/stream/{model_id}
# stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time'
# skips history and just streams new log entries
# appending ?no-history will disable sending buffered history first
curl -Ns 'http://host/logs/stream?no-history'
```
@@ -0,0 +1,85 @@
# Replace ring.Ring with Efficient Circular Byte Buffer
## Overview
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
## Current Issues
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
4. Linked list structure has poor cache locality and pointer overhead
## Design Requirements
### New CircularBuffer Type
Create a simple circular byte buffer with:
- Single pre-allocated `[]byte` of fixed capacity (10KB)
- `head` and `size` integers to track write position and data length
- No per-write allocations
### API Requirements
The new buffer must support:
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
### Implementation Details
```go
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
}
```
**Write logic:**
- If `len(p) >= capacity`: just keep the last `capacity` bytes
- Otherwise: write bytes at `head`, wrapping around if needed
- Update `head` and `size` accordingly
- Data is copied into the internal buffer (not stored by reference)
**GetHistory logic:**
- Calculate start position: `(head - size + cap) % cap`
- If not wrapped: single slice copy
- If wrapped: two copies (end of buffer + beginning)
- Returns a **new slice** (copy), not a view into internal buffer
### Immutability Guarantees (must preserve)
Per existing tests:
1. Modifying input `[]byte` after `Write()` must not affect stored data
2. `GetHistory()` returns independent copy - modifications don't affect buffer
## Files to Modify
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
## Testing Plan
Existing tests in `logMonitor_test.go` should continue to pass:
- `TestLogMonitor` - Basic write/read and subscriber notification
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
- `TestWrite_LogTimeFormat` - Timestamp formatting
Add new tests:
- Test buffer wrap-around behavior
- Test large writes that exceed buffer capacity
- Test exact capacity boundary conditions
## Checklist
- [ ] Create `circularBuffer` struct in `logMonitor.go`
- [ ] Implement `Write()` method for circular buffer
- [ ] Implement `GetHistory()` method for circular buffer
- [ ] Update `LogMonitor` struct to use new buffer
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
- [ ] Update `LogMonitor.Write()` to use new buffer
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
- [ ] Remove `"container/ring"` import
- [ ] Run `make test-dev` to verify existing tests pass
- [ ] Add wrap-around test case
- [ ] Run `make test-all` for final validation
+42
View File
@@ -210,6 +210,11 @@ func main() {
})
})
r.GET("/v1/audio/voices", func(c *gin.Context) {
model := c.Query("model")
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
})
r.GET("/slow-respond", func(c *gin.Context) {
echo := c.Query("echo")
delay := c.Query("delay")
@@ -269,6 +274,43 @@ func main() {
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
})
// SD API endpoints
r.POST("/sdapi/v1/txt2img", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
c.JSON(http.StatusOK, gin.H{
"model": modelName,
"images": []string{},
})
})
r.POST("/sdapi/v1/img2img", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
c.JSON(http.StatusOK, gin.H{
"model": modelName,
"images": []string{},
})
})
r.GET("/sdapi/v1/loras", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"loras": []string{},
})
})
address := "127.0.0.1:" + *port // Address with the specified port
srv := &http.Server{
+187 -5
View File
@@ -39,6 +39,49 @@
},
"default": {},
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
},
"timeouts": {
"type": "object",
"properties": {
"connect": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP connection timeout in seconds. Set to 0 to disable."
},
"keepalive": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP keepalive timeout in seconds. Set to 0 to disable."
},
"responseHeader": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Time to wait for response headers in seconds. Set to 0 to disable."
},
"tlsHandshake": {
"type": "integer",
"minimum": 0,
"default": 10,
"description": "TLS handshake timeout in seconds. Set to 0 to disable."
},
"expectContinue": {
"type": "integer",
"minimum": 0,
"default": 1,
"description": "Expect-Continue timeout in seconds. Set to 0 to disable."
},
"idleConn": {
"type": "integer",
"minimum": 0,
"default": 90,
"description": "Idle connection timeout in seconds. Set to 0 to disable."
}
},
"additionalProperties": false,
"description": "Timeout settings for proxy connections."
}
},
"properties": {
@@ -48,6 +91,12 @@
"default": 120,
"description": "Number of seconds to wait for a model to be ready to serve requests."
},
"globalTTL": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Default TTL for all models in seconds, 0 means no TTL and models will never be automatically unloaded"
},
"logLevel": {
"type": "string",
"enum": [
@@ -87,6 +136,12 @@
"default": 1000,
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
},
"captureBuffer": {
"type": "integer",
"minimum": 0,
"default": 5,
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
},
"startPort": {
"type": "integer",
"default": 5800,
@@ -171,9 +226,9 @@
},
"ttl": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
"minimum": -1,
"default": -1,
"description": "Automatically unload the model after ttl seconds. -1 uses the global TTL value, 0 disables unloading. Must be >0 to enable."
},
"useModelName": {
"type": "string",
@@ -188,11 +243,26 @@
"default": "",
"pattern": "^[a-zA-Z0-9_, ]*$",
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
},
"setParams": {
"type": "object",
"additionalProperties": true,
"default": {},
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
},
"setParamsByID": {
"type": "object",
"additionalProperties": {
"type": "object",
"additionalProperties": true
},
"default": {},
"description": "Dictionary mapping requested model IDs (or aliases) to parameters to set/override in requests. Applied after setParams and can override those values. Useful with aliases to vary behaviour depending on which alias the client used (e.g. different reasoning_effort per alias). Keys support ${MODEL_ID} macro substitution. Protected params like 'model' cannot be overridden."
}
},
"additionalProperties": false,
"default": {},
"description": "Dictionary of filter settings. Only stripParams is supported."
"description": "Dictionary of filter settings. Supports stripParams, setParams, and setParamsByID."
},
"metadata": {
"type": "object",
@@ -214,6 +284,9 @@
"type": "boolean",
"default": false,
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
},
"timeouts": {
"$ref": "#/definitions/timeouts"
}
}
}
@@ -273,6 +346,115 @@
},
"additionalProperties": false,
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
},
"logToStdout": {
"type": "string",
"enum": [
"proxy",
"upstream",
"both",
"none"
],
"default": "proxy",
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
},
"apiKeys": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"default": [],
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
},
"peers": {
"type": "object",
"additionalProperties": {
"type": "object",
"required": [
"proxy",
"models"
],
"properties": {
"proxy": {
"type": "string",
"format": "uri",
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
},
"apiKey": {
"type": "string",
"default": "",
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
},
"models": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"description": "A list of models served by the peer."
},
"filters": {
"type": "object",
"properties": {
"stripParams": {
"type": "string",
"default": "",
"pattern": "^[a-zA-Z0-9_, ]*$",
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
},
"setParams": {
"type": "object",
"additionalProperties": true,
"default": {},
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
}
},
"additionalProperties": false,
"default": {},
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
},
"timeouts": {
"type": "object",
"properties": {
"connect": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP connection timeout in seconds."
},
"keepalive": {
"type": "integer",
"minimum": 0,
"default": 30,
"description": "TCP keepalive connection timeout in seconds."
},
"responseHeader": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Time to wait for response headers in seconds."
},
"tlsHandshake": {
"type": "integer",
"minimum": 0,
"default": 10,
"description": "TLS handshake timeout in seconds."
},
"idleConn": {
"type": "integer",
"minimum": 0,
"default": 90,
"description": "Idle connection timeout in seconds."
}
},
"additionalProperties": false,
"description": "Timeout settings for proxy connections to this peer."
}
}
},
"default": {},
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
}
}
}
}
+168 -16
View File
@@ -34,12 +34,27 @@ logLevel: info
# - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: ""
# logToStdout: controls what is logged to stdout
# - optional, default: "proxy"
# - valid values:
# - "proxy": logs generated by llama-swap when swapping models,
# handling requests, etc.
# - "upstream": a copy of an upstream processes stdout logs
# - "both": both the proxy and upstream logs interleaved together
# - "none": no logs are ever written to stdout
logToStdout: "proxy"
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
# - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000
# captureBuffer: how many MBs to allocate for storing request/response captures
# - optional, default: 10
# - set to 0 to disable
captureBuffer: 15
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
@@ -60,6 +75,11 @@ sendLoadingState: true
# all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false
# globalTTL: the default TTL in seconds before unloading a model
# - optional, default: 0 (never automatically unload)
# - must be >= 0
globalTTL: 0
# macros: a dictionary of string substitutions
# - optional, default: empty dictionary
# - macros are reusable snippets
@@ -70,6 +90,9 @@ includeAliasesInList: false
# - macro names must not be a reserved name: PORT or MODEL_ID
# - macro values can be numbers, bools, or strings
# - macros can contain other macros, but they must be defined before they are used
# - environment variables can be referenced with ${env.VAR_NAME} syntax
# - env macros are substituted first, before regular macros
# - if the env var is not set, config loading will fail with an error
macros:
# Example of a multi-line macro
"latest-llama": >
@@ -82,6 +105,24 @@ macros:
# but they must be previously declared.
"default_args": "--ctx-size ${default_ctx}"
# Example of environment variable macros
# - ${env.VAR_NAME} pulls the value from the system environment
# - useful for paths, secrets, or machine-specific configuration
"models_dir": "${env.HOME}/models"
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
# use environment variable macros to keep secrets out of the config
- "${env.API_KEY_1}"
- "${env.API_KEY_2}"
# models: a dictionary of model configurations
# - required
# - each key is the model's ID, used in API requests
@@ -90,7 +131,7 @@ macros:
# - below are examples of the all the settings a model can have
models:
# keys are the model names used in API requests
"llama":
"gpt-oss-120b":
# macros: a dictionary of string substitutions specific to this model
# - optional, default: empty dictionary
# - macros defined here override macros defined in the global macros section
@@ -107,7 +148,7 @@ models:
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model path/to/llama-8B-Q4_K_M.gguf
--model path/to/gpt-oss-120B.gguf
--ctx-size ${default_ctx}
--temperature ${temp}
@@ -115,13 +156,13 @@ models:
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
name: "llama 3.1 8B"
name: "gpt-oss 120B"
# description: a description for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
description: "A small but capable model used for quick testing"
description: "A thinking model from OpenAI"
# env: define an array of environment variables to inject into cmd's environment
# - optional, default: empty array
@@ -136,14 +177,6 @@ models:
# - if you use a custom port in cmd this *must* be set
proxy: http://127.0.0.1:8999
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
- "gpt-3.5-turbo"
# checkEndpoint: URL path to check if the server is ready
# - optional, default: /health
# - endpoint is expected to return an HTTP 200 response
@@ -152,8 +185,10 @@ models:
checkEndpoint: /custom-endpoint
# ttl: automatically unload the model after ttl seconds
# - optional, default: 0
# - ttl values must be a value greater than 0
# - optional, default: -1 (use global default)
# - ttl values must be a value greater than or equal to 0
# - a ttl of -1 will use the global TTL value as the default
# - a ttl of 0 will mean never unload
# - a value of 0 disables automatic unloading of the model
ttl: 60
@@ -161,11 +196,11 @@ models:
# - optional, default: ""
# - useful for when the upstream server expects a specific model name that
# is different from the model's ID
useModelName: "qwen:qwq"
useModelName: "openai/gpt-oss-120B"
# filters: a dictionary of filter settings
# - optional, default: empty dictionary
# - only stripParams is currently supported
# - same capabilities as peer filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
@@ -175,6 +210,43 @@ models:
# - recommended to stick to sampling parameters
stripParams: "temperature, top_p, top_k"
# setParams: a dictionary of parameters to set/override in requests
# - optional, default: empty dictionary
# - useful for enforcing specific parameter values
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
# - always runs for the model
setParams:
# Example: enforce specific sampling parameters
temperature: 0.7
top_p: 0.9
# setParamsByID: a dictionary of parameters to set based the model ID
# - optional, default: empty dictionary
# - combine with aliases to create variant behaviour without reloading the model
# - parameters are set in the request body JSON
# - run after setParams so it will override any settings
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
# - model aliases will be automatically created for each key
setParamsByID:
"${MODEL_ID}":
chat_template_kwargs:
reasoning_effort: medium
"${MODEL_ID}:high":
chat_template_kwargs:
reasoning_effort: high
"${MODEL_ID}:low":
chat_template_kwargs:
reasoning_effort: low
# aliases: alternative model names that this model configuration is used for
# - optional, default: empty array
# - aliases must be unique globally
# - useful for impersonating a specific model
aliases:
- "gpt-4o-mini"
# metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple
@@ -212,6 +284,22 @@ models:
# - optional, default: undefined (use global setting)
sendLoadingState: false
# timeouts: configure proxy connection timeouts for this model
# - optional, defaults shown below
# - useful for models running on slower hardware that need longer timeouts
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
# - idleConn: idle connection timeout in seconds, default: 90 seconds
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
keepalive: 0
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
@@ -321,3 +409,67 @@ hooks:
# otherwise models will be loaded and swapped out
preload:
- "llama"
# peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary
# - peers can be another llama-swap
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
peers:
# keys is the peer'd ID
llama-swap-peer:
# proxy: a valid base URL to proxy requests to
# - required
# - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23
# models: a list of models served by the peer
# - required
models:
- model_a
- model_b
- embeddings/model_c
openrouter:
proxy: https://openrouter.ai/api
# apiKey: a string key to be injected into the request
# - optional, default: ""
# - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
# - can be a string or a macro
apiKey: ${env.OPENROUTER_API_KEY}
models:
- meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507
- deepseek/deepseek-v3.2
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
# timeouts: configure proxy connection timeouts for this peer
# - optional, defaults shown below
# - useful when the peer runs on slower hardware
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
keepalive: 30
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# filters: a dictionary of filter settings for peer requests
# - optional, default: empty dictionary
# - same capabilities as model filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for removing parameters that the peer doesn't support
# - the `model` parameter can never be removed
stripParams: "temperature, top_p"
# setParams: a dictionary of parameters to set/override in requests to this peer
# - optional, default: empty dictionary
# - useful for injecting provider-specific settings like data retention policies
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
setParams:
# Example: enforce zero-data-retention for OpenRouter
provider:
data_collection: "deny"
zdr: true
+114 -20
View File
@@ -1,53 +1,134 @@
#!/bin/bash
set -euo pipefail
cd $(dirname "$0")
# use this to test locally, example:
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
# you need read:package scope on the token. Generate a personal access token with
# the scopes: gist, read:org, repo, write:packages
# then: gh auth login (and copy/paste the new token)
LOG_DEBUG=${LOG_DEBUG:-0}
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
log_debug() {
if [ "$LOG_DEBUG" = "1" ]; then
echo "[DEBUG] $*"
fi
}
log_info() {
echo "[INFO] $*"
}
ARCH=$1
PUSH_IMAGES=${2:-false}
# List of allowed architectures
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cuda13" "cpu" "rocm")
# Check if ARCH is in the allowed list
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
exit 1
fi
# Check if GITHUB_TOKEN is set and not empty
if [[ -z "$GITHUB_TOKEN" ]]; then
echo "Error: GITHUB_TOKEN is not set or is empty."
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
log_info "Error: GITHUB_TOKEN is not set or is empty."
exit 1
fi
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
# variable, this permits testing with forked llama.cpp repositories
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
# to enable easy container builds on forked repos
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
# the most recent llama-swap tag
# have to strip out the 'v' due to .tar.gz file naming
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
# Fetches the most recent llama.cpp tag matching the given prefix
# Handles pagination to search beyond the first 100 results
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
# Returns: the version number extracted from the tag
fetch_llama_tag() {
local tag_prefix=$1
local page=1
local per_page=100
while true; do
log_debug "Fetching page $page for tag prefix: $tag_prefix"
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
# Check for API errors
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
local error_msg=$(echo "$response" | jq -r '.message')
log_info "GitHub API error: $error_msg"
return 1
fi
# Check if response is empty array (no more pages)
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
log_debug "No more pages (empty response)"
return 1
fi
# Extract matching tag from this page
local found_tag=$(echo "$response" | jq -r \
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
| sort -r | head -n1)
if [ -n "$found_tag" ]; then
log_debug "Found tag: $found_tag on page $page"
echo "$found_tag" | awk -F '-' '{print $NF}'
return 0
fi
page=$((page + 1))
# Safety limit to prevent infinite loops
if [ $page -gt 50 ]; then
log_info "Reached pagination safety limit (50 pages)"
return 1
fi
done
}
if [ "$ARCH" == "cpu" ]; then
# cpu only containers just use the server tag
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r '.[] | select(.metadata.container.tags[] | startswith("server")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
LCPP_TAG=$(fetch_llama_tag "server")
BASE_TAG=server-${LCPP_TAG}
else
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
| sort -r | head -n1 | awk -F '-' '{print $3}')
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
BASE_TAG=server-${ARCH}-${LCPP_TAG}
fi
SD_TAG=master-${ARCH}
# Abort if LCPP_TAG is empty.
if [[ -z "$LCPP_TAG" ]]; then
echo "Abort: Could not find llama-server container for arch: $ARCH"
log_info "Abort: Could not find llama-server container for arch: $ARCH"
exit 1
else
log_info "LCPP_TAG: $LCPP_TAG"
fi
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
log_info "Abort: DEBUG_ABORT_BUILD set"
exit 0
fi
for CONTAINER_TYPE in non-root root; do
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
USER_UID=0
USER_GID=0
USER_HOME=/root
@@ -60,9 +141,22 @@ for CONTAINER_TYPE in non-root root; do
USER_HOME=/app
fi
echo "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
--build-arg BASE_IMAGE=${BASE_IMAGE} .
# For architectures with stable-diffusion.cpp support, layer sd-server on top
case "$ARCH" in
"musa" | "vulkan")
log_info "Adding sd-server to $CONTAINER_TAG"
docker build --provenance=false -f llama-swap-sd.Containerfile \
--build-arg BASE=${CONTAINER_TAG} \
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
esac
if [ "$PUSH_IMAGES" == "true" ]; then
docker push ${CONTAINER_TAG}
docker push ${CONTAINER_LATEST}
+305
View File
@@ -0,0 +1,305 @@
#!/bin/bash
#
# Build script for llama-swap-docker with commit hash pinning
#
# Usage:
# ./build-image.sh --cuda # Build CUDA image
# ./build-image.sh --vulkan # Build Vulkan image
# ./build-image.sh --cuda --no-cache # Build CUDA image without cache
# LLAMA_COMMIT_HASH=abc123 ./build-image.sh --cuda # Override llama.cpp commit
# LLAMA_COMMIT_HASH=b8429 ./build-image.sh --vulkan # Override llama.cpp release tag (vulkan uses prebuilt binaries)
# WHISPER_COMMIT_HASH=def456 ./build-image.sh --vulkan # Override whisper.cpp commit
# SD_COMMIT_HASH=ghi789 ./build-image.sh --cuda # Override stable-diffusion.cpp commit
#
# Features:
# - Auto-detects latest commit hashes from git repos
# - Builds llama-swap from local source code
# - Allows environment variable overrides for reproducible builds
# - Cache-friendly: changing commit hash busts cache appropriately
# - Supports both CUDA and Vulkan backends (requires explicit flag)
#
set -euo pipefail
# Parse command line arguments
BACKEND=""
NO_CACHE=false
if [[ $# -eq 0 ]]; then
echo "Error: No backend specified. Please use --cuda or --vulkan."
echo ""
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
echo ""
echo "Options:"
echo " --cuda Build CUDA image (NVIDIA GPUs)"
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
echo " --no-cache Force rebuild without using Docker cache"
echo " --help, -h Show this help message"
echo ""
echo "Environment variables:"
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
exit 1
fi
for arg in "$@"; do
case $arg in
--cuda)
BACKEND="cuda"
;;
--vulkan)
BACKEND="vulkan"
;;
--no-cache)
NO_CACHE=true
;;
--help|-h)
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
echo ""
echo "Options:"
echo " --cuda Build CUDA image (NVIDIA GPUs)"
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
echo " --no-cache Force rebuild without using Docker cache"
echo " --help, -h Show this help message"
echo ""
echo "Environment variables:"
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
exit 0
;;
esac
done
# Validate backend selection
if [[ -z "$BACKEND" ]]; then
echo "Error: No backend specified. Please use --cuda or --vulkan."
exit 1
fi
# Configuration
if [[ -n "${DOCKER_IMAGE_TAG:-}" ]]; then
# User provided a custom tag, use it as-is
:
elif [[ "$BACKEND" == "vulkan" ]]; then
DOCKER_IMAGE_TAG="llama-swap:vulkan"
else
DOCKER_IMAGE_TAG="llama-swap:cuda"
fi
DOCKER_BUILDKIT="${DOCKER_BUILDKIT:-1}"
# Single unified Dockerfile, backend selected via build arg
DOCKERFILE="Dockerfile"
if [[ "$BACKEND" == "vulkan" ]]; then
echo "Building for: Vulkan (AMD GPUs and compatible hardware)"
else
echo "Building for: CUDA (NVIDIA GPUs)"
fi
# Git repository URLs
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
# Function to get the latest commit hash from a git repo's default branch
get_latest_commit() {
local repo_url="$1"
local branch="${2:-master}"
# Try to get the latest commit hash for the specified branch
git ls-remote --heads "${repo_url}" "${branch}" 2>/dev/null | head -1 | cut -f1
}
# Function to get the default branch name (master or main)
get_default_branch() {
local repo_url="$1"
# Check for master first
if git ls-remote --heads "${repo_url}" master &>/dev/null; then
echo "master"
elif git ls-remote --heads "${repo_url}" main &>/dev/null; then
echo "main"
else
echo "master" # fallback
fi
}
# Function to get the latest release tag from a GitHub repo
get_latest_release_tag() {
local owner_repo="$1"
curl -fsSL "https://api.github.com/repos/${owner_repo}/releases/latest" \
| grep '"tag_name"' | head -1 | cut -d'"' -f4
}
echo "=========================================="
echo "llama-swap-docker Build Script"
echo "=========================================="
echo ""
# Determine commit hashes / release tags - use env vars or auto-detect
# For vulkan builds, llama and sd use GitHub release tags (prebuilt binaries).
# For cuda builds (or whisper on any backend), use git commit hashes.
if [[ -n "${LLAMA_COMMIT_HASH:-}" ]]; then
LLAMA_HASH="${LLAMA_COMMIT_HASH}"
echo "llama.cpp: Using provided version: ${LLAMA_HASH}"
elif [[ "$BACKEND" == "vulkan" ]]; then
LLAMA_HASH=$(get_latest_release_tag "ggml-org/llama.cpp")
if [[ -z "${LLAMA_HASH}" ]]; then
echo "ERROR: Could not determine latest release tag for llama.cpp" >&2
exit 1
fi
echo "llama.cpp: Auto-detected latest release tag: ${LLAMA_HASH}"
else
LLAMA_BRANCH=$(get_default_branch "${LLAMA_REPO}")
LLAMA_HASH=$(get_latest_commit "${LLAMA_REPO}" "${LLAMA_BRANCH}")
if [[ -z "${LLAMA_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
exit 1
fi
echo "llama.cpp: Auto-detected latest commit (${LLAMA_BRANCH}): ${LLAMA_HASH}"
fi
if [[ -n "${WHISPER_COMMIT_HASH:-}" ]]; then
WHISPER_HASH="${WHISPER_COMMIT_HASH}"
echo "whisper.cpp: Using provided commit hash: ${WHISPER_HASH}"
else
WHISPER_BRANCH=$(get_default_branch "${WHISPER_REPO}")
WHISPER_HASH=$(get_latest_commit "${WHISPER_REPO}" "${WHISPER_BRANCH}")
if [[ -z "${WHISPER_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
exit 1
fi
echo "whisper.cpp: Auto-detected latest commit (${WHISPER_BRANCH}): ${WHISPER_HASH}"
fi
if [[ -n "${SD_COMMIT_HASH:-}" ]]; then
SD_HASH="${SD_COMMIT_HASH}"
echo "stable-diffusion.cpp: Using provided version: ${SD_HASH}"
elif [[ "$BACKEND" == "vulkan" ]]; then
SD_HASH=$(get_latest_release_tag "leejet/stable-diffusion.cpp")
if [[ -z "${SD_HASH}" ]]; then
echo "ERROR: Could not determine latest release tag for stable-diffusion.cpp" >&2
exit 1
fi
echo "stable-diffusion.cpp: Auto-detected latest release tag: ${SD_HASH}"
else
SD_BRANCH=$(get_default_branch "${SD_REPO}")
SD_HASH=$(get_latest_commit "${SD_REPO}" "${SD_BRANCH}")
if [[ -z "${SD_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
exit 1
fi
echo "stable-diffusion.cpp: Auto-detected latest commit (${SD_BRANCH}): ${SD_HASH}"
fi
echo ""
echo "=========================================="
echo "Starting Docker build..."
echo "=========================================="
echo ""
# Build the Docker image with commit hashes as build args
# Build context is the repository root (..) so the Dockerfile can access Go source
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
BUILD_ARGS=(
--build-arg "BACKEND=${BACKEND}"
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
-t "${DOCKER_IMAGE_TAG}"
-f "${SCRIPT_DIR}/${DOCKERFILE}"
)
if [[ "$NO_CACHE" == true ]]; then
BUILD_ARGS+=(--no-cache)
echo "Note: Building without cache"
fi
# Use docker buildx with a custom builder for parallelism control
# The legacy DOCKER_BUILDKIT=1 docker build doesn't respect BUILDKIT_MAX_PARALLELISM env var
# We need to use a custom builder with a buildkitd.toml config file
BUILDER_NAME="llama-swap-builder"
# Check if our custom builder exists with the right config, create/update if needed
if ! docker buildx inspect "$BUILDER_NAME" >/dev/null 2>&1; then
echo "Creating custom buildx builder with max-parallelism=1..."
# Create buildkitd.toml config file
cat > buildkitd.toml << 'BUILDKIT_EOF'
[worker.oci]
max-parallelism = 1
BUILDKIT_EOF
# Create the builder with the config
docker buildx create --name "$BUILDER_NAME" \
--driver docker-container \
--buildkitd-config buildkitd.toml \
--use
else
# Switch to our builder
docker buildx use "$BUILDER_NAME"
fi
echo "Building with sequential stages (one at a time), each using all CPU cores..."
echo "Using builder: $BUILDER_NAME"
# Use docker buildx build with --load to load the image into Docker
# The --builder flag ensures we use our custom builder with max-parallelism=1
# Build context is the repository root so we can access Go source files
docker buildx build --builder "$BUILDER_NAME" --load "${BUILD_ARGS[@]}" "${REPO_ROOT}"
echo ""
echo "=========================================="
echo "Verifying build artifacts..."
echo "=========================================="
echo ""
# Verify all expected binaries exist in the image
MISSING_BINARIES=()
for binary in llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap; do
if ! docker run --rm "${DOCKER_IMAGE_TAG}" which "${binary}" >/dev/null 2>&1; then
MISSING_BINARIES+=("${binary}")
fi
done
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
echo "ERROR: Build succeeded but the following binaries are missing from the image:"
for binary in "${MISSING_BINARIES[@]}"; do
echo " - ${binary}"
done
echo ""
echo "This usually indicates a build stage failure. Try running with --no-cache flag:"
echo " ./build-image.sh --vulkan --no-cache"
exit 1
fi
echo "All expected binaries verified: llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
echo ""
echo "=========================================="
echo "Build complete!"
echo "=========================================="
echo ""
echo "Image tag: ${DOCKER_IMAGE_TAG}"
echo ""
echo "Built with:"
echo " llama.cpp: ${LLAMA_HASH}"
echo " whisper.cpp: ${WHISPER_HASH}"
echo " stable-diffusion.cpp: ${SD_HASH}"
echo " llama-swap: $(docker run --rm "${DOCKER_IMAGE_TAG}" cat /versions.txt | grep llama-swap | cut -d' ' -f2-)"
echo ""
if [[ "$BACKEND" == "vulkan" ]]; then
echo "Run with:"
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
echo ""
echo "Note: For AMD GPUs, you may also need to mount render devices:"
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
else
echo "Run with:"
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
fi
+16 -1
View File
@@ -15,4 +15,19 @@ models:
cmd: >
/app/llama-server
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
--port 9999
--port 9999
z-image:
checkEndpoint: /
cmd: |
/app/sd-server
--listen-port 9999
--diffusion-fa
--diffusion-model /models/z_image_turbo-Q8_0.gguf
--vae /models/ae.safetensors
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
--offload-to-cpu
--cfg-scale 1.0
--height 512 --width 512
--steps 8
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
+11
View File
@@ -0,0 +1,11 @@
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
ARG SD_TAG=master-vulkan
ARG BASE=llama-swap:latest
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
FROM ${BASE}
ARG UID=10001
ARG GID=10001
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
+10 -4
View File
@@ -1,8 +1,10 @@
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
ARG BASE_TAG=server-cuda
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
FROM ${BASE_IMAGE}:${BASE_TAG}
# has to be after the FROM
ARG LS_VER=170
ARG LS_REPO=mostlygeek/llama-swap
# Set default UID/GID arguments
ARG UID=10001
@@ -27,10 +29,14 @@ RUN chown --recursive $UID:$GID $HOME /app
USER $UID:$GID
WORKDIR /app
# Add /app to PATH
ENV PATH="/app:${PATH}"
RUN \
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
+203
View File
@@ -0,0 +1,203 @@
# Unified multi-stage Dockerfile for AI inference tools
# Supports CUDA and Vulkan backends via BACKEND build arg
#
# Usage:
# docker buildx build --build-arg BACKEND=cuda -t llama-swap:unified-cuda .
# docker buildx build --build-arg BACKEND=vulkan -t llama-swap:unified-vulkan .
# docker buildx build --build-arg BACKEND=cuda --build-arg CMAKE_CUDA_ARCHITECTURES="86;89" -t llama-swap:unified-cuda .
#
# Each project has its own install script that handles cloning, building,
# and installing binaries. Build stages are independent for cache efficiency.
ARG BACKEND=cuda
# ── Builder bases ──────────────────────────────────────────────────────
FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 AS builder-base-cuda
ARG CMAKE_CUDA_ARCHITECTURES="60;61;75;86;89"
ENV DEBIAN_FRONTEND=noninteractive
ENV CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}
ENV CCACHE_DIR=/ccache
ENV CCACHE_MAXSIZE=2G
ENV PATH="/usr/lib/ccache:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git python3 python3-pip libssl-dev \
curl ca-certificates ccache make wget \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
# ──
FROM ubuntu:24.04 AS builder-base-vulkan
ENV DEBIAN_FRONTEND=noninteractive
ENV CCACHE_DIR=/ccache
ENV CCACHE_MAXSIZE=2G
ENV PATH="/usr/lib/ccache:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git python3 python3-pip libssl-dev \
curl ca-certificates ccache make wget software-properties-common \
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
# ── Select builder base by BACKEND ────────────────────────────────────
FROM builder-base-${BACKEND} AS builder-base
# ── Build whisper.cpp (fastest build, run first) ──────────────────────
FROM builder-base AS whisper-build
ARG BACKEND=cuda
ARG WHISPER_COMMIT_HASH=master
COPY install-whisper.sh /build/
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
--mount=type=cache,id=whisper-${BACKEND},target=/src/whisper.cpp/build \
BACKEND=${BACKEND} bash /build/install-whisper.sh "${WHISPER_COMMIT_HASH}"
# ── Build stable-diffusion.cpp ────────────────────────────────────────
FROM builder-base AS sd-build
ARG BACKEND=cuda
ARG SD_COMMIT_HASH=master
COPY install-sd.sh /build/
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
--mount=type=cache,id=sd-${BACKEND},target=/src/stable-diffusion.cpp/build \
BACKEND=${BACKEND} bash /build/install-sd.sh "${SD_COMMIT_HASH}"
# ── Build llama.cpp (slowest build, run last) ─────────────────────────
FROM builder-base AS llama-build
ARG BACKEND=cuda
ARG LLAMA_COMMIT_HASH=master
COPY install-llama.sh /build/
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
--mount=type=cache,id=llama-${BACKEND},target=/src/llama.cpp/build \
BACKEND=${BACKEND} bash /build/install-llama.sh "${LLAMA_COMMIT_HASH}"
# ── Build ik_llama.cpp (CUDA only) ────────────────────────────────────
#
# Two named stages allow ARG BACKEND to select at build time:
# - ik-llama-cuda : real build (from builder-base-cuda)
# - ik-llama-vulkan: no-op (empty /install/bin, skips CUDA pull entirely)
# BuildKit only evaluates the selected branch, so vulkan builds never
# pull nvidia/cuda:*-devel or compile ik_llama.cpp.
FROM builder-base-vulkan AS ik-llama-vulkan
RUN mkdir -p /install/bin
FROM builder-base-cuda AS ik-llama-cuda
ARG IK_LLAMA_COMMIT_HASH=main
COPY install-ik-llama.sh /build/
RUN --mount=type=cache,id=ccache-cuda,target=/ccache \
--mount=type=cache,id=ik-llama-cuda,target=/src/ik_llama.cpp/build \
bash /build/install-ik-llama.sh "${IK_LLAMA_COMMIT_HASH}"
ARG BACKEND=cuda
FROM ik-llama-${BACKEND} AS ik-llama-build
# ── Download llama-swap release binary ────────────────────────────────
FROM builder-base AS llama-swap-download
ARG LS_VERSION=latest
COPY install-llama-swap.sh /build/
RUN bash /build/install-llama-swap.sh "${LS_VERSION}"
# ── Runtime bases ─────────────────────────────────────────────────────
FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04 AS runtime-cuda
ENV DEBIAN_FRONTEND=noninteractive
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV PATH="/usr/local/bin:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
libgomp1 python3 curl ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# CUDA stub drivers for container compatibility
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
# ──
FROM ubuntu:24.04 AS runtime-vulkan
ENV DEBIAN_FRONTEND=noninteractive
ENV PATH="/usr/local/bin:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
libgomp1 libvulkan1 mesa-vulkan-drivers \
python3 curl ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# ── Select runtime base by BACKEND ────────────────────────────────────
FROM runtime-${BACKEND} AS runtime
ARG BACKEND=cuda
ARG LLAMA_COMMIT_HASH=unknown
ARG WHISPER_COMMIT_HASH=unknown
ARG SD_COMMIT_HASH=unknown
ARG IK_LLAMA_COMMIT_HASH=unknown
ARG RUN_UID=0
RUN apt-get update && apt-get install -y --no-install-recommends \
python3-numpy python3-sentencepiece \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user when RUN_UID != 0
RUN if [ "$RUN_UID" != "0" ]; then \
groupadd --system --gid $RUN_UID llama-swap && \
useradd --system --uid $RUN_UID --gid $RUN_UID \
--home /app --shell /sbin/nologin llama-swap; \
fi && \
mkdir -p /etc/llama-swap/config && \
chown -R ${RUN_UID}:${RUN_UID} /etc/llama-swap
WORKDIR /app
# Copy whisper.cpp binaries and libraries
COPY --from=whisper-build /install/bin/whisper-server /usr/local/bin/
COPY --from=whisper-build /install/bin/whisper-cli /usr/local/bin/
COPY --from=whisper-build /install/lib/ /usr/local/lib/
# Copy stable-diffusion.cpp binaries and libraries
COPY --from=sd-build /install/bin/sd-server /usr/local/bin/
COPY --from=sd-build /install/bin/sd-cli /usr/local/bin/
COPY --from=sd-build /install/lib/ /usr/local/lib/
# Copy llama.cpp binaries (statically linked)
COPY --from=llama-build /install/bin/llama-server /usr/local/bin/
COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
# Copy ik-llama-server (CUDA only; empty copy for vulkan)
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
# Copy llama-swap binary
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
COPY --from=llama-swap-download /install/llama-swap-version /tmp/
RUN ldconfig
COPY config.example.yaml /etc/llama-swap/config/config.yaml
# Version tracking
RUN echo "llama.cpp: ${LLAMA_COMMIT_HASH}" > /versions.txt && \
echo "whisper.cpp: ${WHISPER_COMMIT_HASH}" >> /versions.txt && \
echo "stable-diffusion.cpp: ${SD_COMMIT_HASH}" >> /versions.txt && \
echo "ik_llama.cpp: ${IK_LLAMA_COMMIT_HASH}" >> /versions.txt && \
echo "llama-swap: $(cat /tmp/llama-swap-version)" >> /versions.txt && \
echo "backend: ${BACKEND}" >> /versions.txt && \
echo "build_timestamp: $(date -u +%Y-%m-%dT%H:%M:%SZ)" >> /versions.txt
RUN mkdir -p /models && chown ${RUN_UID}:${RUN_UID} /models
WORKDIR /models
USER ${RUN_UID}
ENTRYPOINT ["llama-swap"]
CMD ["-config", "/etc/llama-swap/config/config.yaml", "-listen", "0.0.0.0:8080"]
+8
View File
@@ -0,0 +1,8 @@
# Unified Docker Container
These scripts create a custom llama-swap container that contains:
- llama-server for LLMs, rerank and embedding model support
- sd-server (stable-diffusion.cpp) for image generation
- whisper.cpp for ASR
+303
View File
@@ -0,0 +1,303 @@
#!/bin/bash
#
# Build script for unified container with version pinning
#
# Usage:
# ./build-image.sh --cuda # Build CUDA image
# ./build-image.sh --vulkan # Build Vulkan image
# ./build-image.sh --cuda --no-cache # Build without cache
# LLAMA_REF=b1234 ./build-image.sh --vulkan # Pin llama.cpp to a commit hash
# LLAMA_REF=v1.2.3 ./build-image.sh --cuda # Pin llama.cpp to a tag
# WHISPER_REF=v1.0.0 ./build-image.sh --vulkan # Pin whisper.cpp to a tag
# SD_REF=master ./build-image.sh --cuda # Pin stable-diffusion.cpp to a branch
# LS_VERSION=170 ./build-image.sh --cuda # Override llama-swap version
# IK_LLAMA_REF=main ./build-image.sh --cuda # Pin ik_llama.cpp to main branch (CUDA only)
#
set -euo pipefail
BACKEND=""
NO_CACHE=false
for arg in "$@"; do
case $arg in
--cuda)
BACKEND="cuda"
;;
--vulkan)
BACKEND="vulkan"
;;
--no-cache)
NO_CACHE=true
;;
--help|-h)
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
echo ""
echo "Options:"
echo " --cuda Build CUDA image (NVIDIA GPUs)"
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
echo " --no-cache Force rebuild without using Docker cache"
echo " --help, -h Show this help message"
echo ""
echo "Environment variables:"
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:unified-cuda or llama-swap:unified-vulkan)"
echo " LLAMA_REF Pin llama.cpp to a commit, tag, or branch"
echo " WHISPER_REF Pin whisper.cpp to a commit, tag, or branch"
echo " SD_REF Pin stable-diffusion.cpp to a commit, tag, or branch"
echo " IK_LLAMA_REF Pin ik_llama.cpp to a commit, tag, or branch (CUDA only)"
echo " LS_VERSION Override llama-swap version (e.g., '170' or 'latest')"
exit 0
;;
esac
done
if [[ -z "$BACKEND" ]]; then
echo "Error: No backend specified. Please use --cuda or --vulkan."
echo ""
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
exit 1
fi
DOCKER_IMAGE_TAG="${DOCKER_IMAGE_TAG:-llama-swap:unified-${BACKEND}}"
# Git repository URLs
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
LLAMA_SWAP_REPO="https://github.com/mostlygeek/llama-swap.git"
IK_LLAMA_REPO="https://github.com/ikawrakow/ik_llama.cpp.git"
# Resolve a git ref (commit hash, tag, or branch) to a full commit hash.
# Requires only: git, network access to the remote.
resolve_ref() {
local repo_url="$1"
local ref="$2"
# Full 40-char SHA — use as-is
if [[ "${ref}" =~ ^[0-9a-f]{40}$ ]]; then
echo "${ref}"
return
fi
# Try tag then branch (exact match)
local hash
hash=$(git ls-remote "${repo_url}" "refs/tags/${ref}" "refs/heads/${ref}" 2>/dev/null | head -1 | cut -f1)
if [[ -n "${hash}" ]]; then
echo "${hash}"
return
fi
# Short hash (7+ chars): scan all refs for a SHA with this prefix
if [[ "${ref}" =~ ^[0-9a-f]{7,}$ ]]; then
hash=$(git ls-remote "${repo_url}" 2>/dev/null | grep "^${ref}" | head -1 | cut -f1)
if [[ -n "${hash}" ]]; then
echo "${hash}"
return
fi
fi
echo "ERROR: Could not resolve ref '${ref}' for ${repo_url}" >&2
if [[ "${ref}" =~ ^[0-9a-f]+$ && ${#ref} -lt 7 ]]; then
echo " Short hashes must be at least 7 characters (got ${#ref})." >&2
else
echo " Tried: tag, branch, git ls-remote prefix match" >&2
fi
echo " Use a full 40-char SHA, a tag name, a branch name, or a 7+ char short hash." >&2
return 1
}
# Resolve HEAD of a repo without needing to know the default branch name.
get_latest_hash() {
git ls-remote "${1}" HEAD 2>/dev/null | head -1 | cut -f1
}
echo "=========================================="
echo "llama-swap Unified Build (${BACKEND})"
echo "=========================================="
echo ""
# Resolve llama.cpp ref
if [[ -n "${LLAMA_REF:-}" ]]; then
LLAMA_HASH=$(resolve_ref "${LLAMA_REPO}" "${LLAMA_REF}") || exit 1
echo "llama.cpp: ${LLAMA_REF} -> ${LLAMA_HASH}"
else
LLAMA_HASH=$(get_latest_hash "${LLAMA_REPO}")
if [[ -z "${LLAMA_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
exit 1
fi
echo "llama.cpp: latest HEAD: ${LLAMA_HASH}"
fi
# Resolve whisper.cpp ref
if [[ -n "${WHISPER_REF:-}" ]]; then
WHISPER_HASH=$(resolve_ref "${WHISPER_REPO}" "${WHISPER_REF}") || exit 1
echo "whisper.cpp: ${WHISPER_REF} -> ${WHISPER_HASH}"
else
WHISPER_HASH=$(get_latest_hash "${WHISPER_REPO}")
if [[ -z "${WHISPER_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
exit 1
fi
echo "whisper.cpp: latest HEAD: ${WHISPER_HASH}"
fi
# Resolve stable-diffusion.cpp ref
if [[ -n "${SD_REF:-}" ]]; then
SD_HASH=$(resolve_ref "${SD_REPO}" "${SD_REF}") || exit 1
echo "stable-diffusion.cpp: ${SD_REF} -> ${SD_HASH}"
else
SD_HASH=$(get_latest_hash "${SD_REPO}")
if [[ -z "${SD_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
exit 1
fi
echo "stable-diffusion.cpp: latest HEAD: ${SD_HASH}"
fi
# Resolve ik_llama.cpp ref (CUDA only)
if [[ "$BACKEND" == "cuda" ]]; then
if [[ -n "${IK_LLAMA_REF:-}" ]]; then
IK_LLAMA_HASH=$(resolve_ref "${IK_LLAMA_REPO}" "${IK_LLAMA_REF}") || exit 1
echo "ik_llama.cpp: ${IK_LLAMA_REF} -> ${IK_LLAMA_HASH}"
else
IK_LLAMA_HASH=$(get_latest_hash "${IK_LLAMA_REPO}")
if [[ -z "${IK_LLAMA_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for ik_llama.cpp" >&2
exit 1
fi
echo "ik_llama.cpp: latest HEAD: ${IK_LLAMA_HASH}"
fi
else
IK_LLAMA_HASH="n/a"
echo "ik_llama.cpp: skipped (vulkan build)"
fi
# Resolve llama-swap ref
if [[ -n "${LS_VERSION:-}" ]]; then
LS_HASH=$(resolve_ref "${LLAMA_SWAP_REPO}" "${LS_VERSION}") || exit 1
echo "llama-swap: ${LS_VERSION} -> ${LS_HASH}"
else
LS_HASH=$(get_latest_hash "${LLAMA_SWAP_REPO}")
if [[ -z "${LS_HASH}" ]]; then
echo "ERROR: Could not determine latest commit for llama-swap" >&2
exit 1
fi
echo "llama-swap: latest HEAD: ${LS_HASH}"
fi
echo ""
echo "=========================================="
echo "Starting Docker build..."
echo "=========================================="
echo ""
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BUILD_ARGS=(
--build-arg "BACKEND=${BACKEND}"
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
--build-arg "IK_LLAMA_COMMIT_HASH=${IK_LLAMA_HASH}"
--build-arg "LS_VERSION=${LS_HASH}"
-t "${DOCKER_IMAGE_TAG}"
-f "${SCRIPT_DIR}/Dockerfile"
)
if [[ "$NO_CACHE" == true ]]; then
BUILD_ARGS+=(--no-cache)
echo "Note: Building without cache"
elif [[ "${GITHUB_ACTIONS:-}" == "true" && "${ACT:-}" != "true" ]]; then
CACHE_REF="ghcr.io/mostlygeek/llama-swap:unified-${BACKEND}-cache"
BUILD_ARGS+=(
--cache-from "type=registry,ref=${CACHE_REF}"
--cache-to "type=registry,ref=${CACHE_REF},mode=max"
)
echo "Note: Using registry cache (${CACHE_REF})"
fi
DOCKER_BUILDKIT=1 docker buildx build --load "${BUILD_ARGS[@]}" "${SCRIPT_DIR}"
echo ""
echo "=========================================="
echo "Verifying build artifacts..."
echo "=========================================="
echo ""
EXPECTED_BINARIES=(llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap)
if [[ "$BACKEND" == "cuda" ]]; then
EXPECTED_BINARIES+=(ik-llama-server)
fi
MISSING_BINARIES=()
for binary in "${EXPECTED_BINARIES[@]}"; do
if ! docker run --rm --entrypoint which "${DOCKER_IMAGE_TAG}" "${binary}" >/dev/null 2>&1; then
MISSING_BINARIES+=("${binary}")
fi
done
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
echo "ERROR: Build succeeded but the following binaries are missing:"
for binary in "${MISSING_BINARIES[@]}"; do
echo " - ${binary}"
done
echo ""
echo "Try running with --no-cache flag:"
echo " ./build-image.sh --${BACKEND} --no-cache"
exit 1
fi
VERIFIED_LIST="llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
if [[ "$BACKEND" == "cuda" ]]; then
VERIFIED_LIST="${VERIFIED_LIST}, ik-llama-server"
fi
echo "All expected binaries verified: ${VERIFIED_LIST}"
echo ""
echo "=========================================="
echo "Building rootless image..."
echo "=========================================="
echo ""
ROOTLESS_TAG="${DOCKER_IMAGE_TAG}-rootless"
docker buildx build --load -t "${ROOTLESS_TAG}" - <<EOF
FROM ${DOCKER_IMAGE_TAG}
USER root
RUN groupadd --system --gid 10001 llama-swap && \\
useradd --system --uid 10001 --gid 10001 \\
--home /app --shell /sbin/nologin llama-swap && \\
chown -R 10001:10001 /etc/llama-swap /models
USER 10001
EOF
echo "Rootless image built: ${ROOTLESS_TAG}"
echo ""
echo "=========================================="
echo "Build complete!"
echo "=========================================="
echo ""
echo "Image tags:"
echo " ${DOCKER_IMAGE_TAG}"
echo " ${ROOTLESS_TAG}"
echo ""
echo "Built with:"
echo " llama.cpp: ${LLAMA_HASH}"
echo " whisper.cpp: ${WHISPER_HASH}"
echo " stable-diffusion.cpp: ${SD_HASH}"
if [[ "$BACKEND" == "cuda" ]]; then
echo " ik_llama.cpp: ${IK_LLAMA_HASH}"
fi
echo " llama-swap: $(docker run --rm --entrypoint cat "${DOCKER_IMAGE_TAG}" /versions.txt | grep llama-swap | cut -d' ' -f2-)"
echo ""
if [[ "$BACKEND" == "vulkan" ]]; then
echo "Run with:"
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
echo ""
echo "Note: For AMD GPUs, you may also need:"
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
else
echo "Run with:"
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
fi
+33
View File
@@ -0,0 +1,33 @@
# placeholder example configuration
healthCheckTimeout: 300
logRequests: true
models:
"llama":
cmd: >
llama-server
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
--port ${PORT}
"whisper":
checkEndpoint: /v1/audio/transcriptions/
cmd: >
whisper-server
--port ${PORT}
--m /models/whisper.bin
--flash-attn
--request-path /v1/audio/transcriptions --inference-path ""
"image":
checkEndpoint: /
cmd: |
/app/sd-server
--listen-port 9999
--diffusion-fa
--diffusion-model /models/z_image_turbo-Q8_0.gguf
--vae /models/ae.safetensors
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
--offload-to-cpu
--cfg-scale 1.0
--height 512 --width 512
--steps 8
+48
View File
@@ -0,0 +1,48 @@
#!/bin/bash
# Install ik_llama.cpp - clone, build, and install binaries
# Usage: ./install-ik-llama.sh <commit_hash>
# Note: CUDA only; always built against builder-base-cuda
set -e
COMMIT_HASH="${1:-main}"
mkdir -p /install/bin
# Clone and checkout (init-based so cache-mounted build dir doesn't break clone)
echo "=== Cloning ik_llama.cpp at ${COMMIT_HASH} ==="
mkdir -p /src/ik_llama.cpp
cd /src/ik_llama.cpp
if [ ! -d .git ]; then
git init
git remote add origin https://github.com/ikawrakow/ik_llama.cpp.git
fi
git fetch --depth=1 origin "${COMMIT_HASH}"
git checkout FETCH_HEAD
CMAKE_FLAGS=(
-DGGML_NATIVE=OFF
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER_LAUNCHER=ccache
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-DGGML_CUDA=ON
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda -Wl,--allow-shlib-undefined"
)
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
echo "=== Building ik_llama.cpp ==="
cmake -B build "${CMAKE_FLAGS[@]}"
cmake --build build --config Release -j"$(nproc)" --target llama-server
if [ ! -f "build/bin/llama-server" ]; then
echo "FATAL: llama-server not found in build/bin/" >&2
exit 1
fi
# Install as ik-llama-server to avoid collision with llama.cpp's llama-server
cp "build/bin/llama-server" "/install/bin/ik-llama-server"
echo "=== ik_llama.cpp build complete ==="
ls -la /install/bin/
+59
View File
@@ -0,0 +1,59 @@
#!/bin/bash
# Install llama-swap - download latest release binary from GitHub
# Usage: ./install-llama-swap.sh [version]
# version: release version number (e.g., "170") or "latest" (default)
set -e
VERSION="${1:-latest}"
REPO="mostlygeek/llama-swap"
mkdir -p /install/bin
# If a full commit hash is given, find the release tag that points to it
if echo "${VERSION}" | grep -qE '^[0-9a-f]{40}$'; then
echo "=== Resolving commit ${VERSION:0:7} to release tag ==="
TAG=$(git ls-remote --tags "https://github.com/${REPO}.git" 2>/dev/null \
| grep "^${VERSION}" | sed 's|.*refs/tags/||' | grep -v '\^{}' | head -1)
if [ -n "${TAG}" ]; then
echo "Resolved to tag: ${TAG}"
VERSION="${TAG#v}"
else
echo "No release tag found for commit ${VERSION:0:7}, using latest"
VERSION="latest"
fi
fi
# Strip leading 'v' prefix so both "198" and "v198" work
VERSION="${VERSION#v}"
# Resolve "latest" to actual version number
if [ "$VERSION" = "latest" ]; then
echo "=== Resolving latest llama-swap release ==="
VERSION=$(curl -fsSL "https://api.github.com/repos/${REPO}/releases/latest" \
| grep '"tag_name"' | head -1 | cut -d'"' -f4 | sed 's/^v//')
if [ -z "$VERSION" ]; then
echo "FATAL: Could not determine latest release version" >&2
exit 1
fi
echo "Latest version: ${VERSION}"
fi
# Download and extract
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_amd64.tar.gz"
echo "=== Downloading llama-swap v${VERSION} ==="
echo "URL: $URL"
curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
tar -xzf /tmp/llama-swap.tar.gz -C /install/bin/
rm /tmp/llama-swap.tar.gz
# Validate
if [ ! -x "/install/bin/llama-swap" ]; then
echo "FATAL: llama-swap binary not found or not executable" >&2
ls -la /install/bin/ >&2
exit 1
fi
echo "$VERSION" > /install/llama-swap-version
echo "=== llama-swap v${VERSION} installed ==="
ls -la /install/bin/llama-swap
+63
View File
@@ -0,0 +1,63 @@
#!/bin/bash
# Install llama.cpp - clone, build, and install binaries
# Usage: BACKEND=cuda|vulkan ./install-llama.sh <commit_hash>
set -e
COMMIT_HASH="${1:-master}"
BACKEND="${BACKEND:-cuda}"
mkdir -p /install/bin
# Clone and checkout (init-based so cache-mounted /src/llama.cpp/build dir doesn't break clone)
echo "=== Cloning llama.cpp at ${COMMIT_HASH} ==="
mkdir -p /src/llama.cpp
cd /src/llama.cpp
if [ ! -d .git ]; then
git init
git remote add origin https://github.com/ggml-org/llama.cpp.git
fi
git fetch --depth=1 origin "${COMMIT_HASH}"
git checkout FETCH_HEAD
# Common cmake flags
CMAKE_FLAGS=(
-DGGML_NATIVE=OFF
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER_LAUNCHER=ccache
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-DLLAMA_BUILD_TESTS=OFF
)
if [ "$BACKEND" = "cuda" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=ON
-DGGML_VULKAN=OFF
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
)
elif [ "$BACKEND" = "vulkan" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=OFF
-DGGML_VULKAN=ON
)
fi
TARGETS=(llama-cli llama-server)
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
echo "=== Building llama.cpp for ${BACKEND} ==="
cmake -B build "${CMAKE_FLAGS[@]}"
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
for bin in "${TARGETS[@]}"; do
if [ ! -f "build/bin/$bin" ]; then
echo "FATAL: $bin not found in build/bin/" >&2
exit 1
fi
cp "build/bin/$bin" "/install/bin/"
done
echo "=== llama.cpp build complete ==="
ls -la /install/bin/
+68
View File
@@ -0,0 +1,68 @@
#!/bin/bash
# Install stable-diffusion.cpp - clone, build, and install binaries and library
# Usage: BACKEND=cuda|vulkan ./install-sd.sh <commit_hash>
set -e
COMMIT_HASH="${1:-master}"
BACKEND="${BACKEND:-cuda}"
mkdir -p /install/bin /install/lib
# Clone and checkout (init-based so cache-mounted /src/stable-diffusion.cpp/build dir doesn't break clone)
echo "=== Cloning stable-diffusion.cpp at ${COMMIT_HASH} ==="
mkdir -p /src/stable-diffusion.cpp
cd /src/stable-diffusion.cpp
if [ ! -d .git ]; then
git init
git remote add origin https://github.com/leejet/stable-diffusion.cpp.git
fi
git fetch --depth=1 origin "${COMMIT_HASH}"
git checkout FETCH_HEAD
git submodule update --init --recursive --depth=1
# Common cmake flags
CMAKE_FLAGS=(
-DGGML_NATIVE=OFF
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER_LAUNCHER=ccache
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-DSD_BUILD_EXAMPLES=ON
)
if [ "$BACKEND" = "cuda" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=ON
-DGGML_VULKAN=OFF
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
-DSD_CUDA=ON
)
elif [ "$BACKEND" = "vulkan" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=OFF
-DGGML_VULKAN=ON
-DSD_VULKAN=ON
)
fi
TARGETS=(stable-diffusion sd-cli sd-server)
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
echo "=== Building stable-diffusion.cpp for ${BACKEND} ==="
cmake -B build "${CMAKE_FLAGS[@]}"
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
for bin in sd-cli sd-server; do
if [ ! -f "build/bin/$bin" ]; then
echo "FATAL: $bin not found in build/bin/" >&2
exit 1
fi
cp "build/bin/$bin" "/install/bin/"
done
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
echo "=== stable-diffusion.cpp build complete ==="
ls -la /install/bin/ /install/lib/
+64
View File
@@ -0,0 +1,64 @@
#!/bin/bash
# Install whisper.cpp - clone, build, and install binaries
# Usage: BACKEND=cuda|vulkan ./install-whisper.sh <commit_hash>
set -e
COMMIT_HASH="${1:-master}"
BACKEND="${BACKEND:-cuda}"
mkdir -p /install/bin /install/lib
# Clone and checkout (init-based so cache-mounted /src/whisper.cpp/build dir doesn't break clone)
echo "=== Cloning whisper.cpp at ${COMMIT_HASH} ==="
mkdir -p /src/whisper.cpp
cd /src/whisper.cpp
if [ ! -d .git ]; then
git init
git remote add origin https://github.com/ggml-org/whisper.cpp.git
fi
git fetch --depth=1 origin "${COMMIT_HASH}"
git checkout FETCH_HEAD
# Common cmake flags
CMAKE_FLAGS=(
-DGGML_NATIVE=OFF
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER_LAUNCHER=ccache
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
)
if [ "$BACKEND" = "cuda" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=ON
-DGGML_VULKAN=OFF
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
)
elif [ "$BACKEND" = "vulkan" ]; then
CMAKE_FLAGS+=(
-DGGML_CUDA=OFF
-DGGML_VULKAN=ON
)
fi
TARGETS=(whisper-cli whisper-server)
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
echo "=== Building whisper.cpp for ${BACKEND} ==="
cmake -B build "${CMAKE_FLAGS[@]}"
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
for bin in "${TARGETS[@]}"; do
if [ ! -f "build/bin/$bin" ]; then
echo "FATAL: $bin not found in build/bin/" >&2
exit 1
fi
cp "build/bin/$bin" "/install/bin/"
done
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
echo "=== whisper.cpp build complete ==="
ls -la /install/bin/

Before

Width:  |  Height:  |  Size: 261 KiB

After

Width:  |  Height:  |  Size: 261 KiB

Before

Width:  |  Height:  |  Size: 351 KiB

After

Width:  |  Height:  |  Size: 351 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

+116 -1
View File
@@ -86,9 +86,12 @@ llama-swap supports many more features to customize how you want to manage your
## Full Configuration Example
> [!NOTE]
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples.
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
```yaml
# add this modeline for validation in vscode
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
#
# llama-swap YAML configuration example
# -------------------------------------
#
@@ -114,6 +117,24 @@ healthCheckTimeout: 500
# - Valid log levels: debug, info, warn, error
logLevel: info
# logTimeFormat: enables and sets the logging timestamp format
# - optional, default (disabled): ""
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
# "stamp", "stampmilli", "stampmicro", and "stampnano".
# - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: ""
# logToStdout: controls what is logged to stdout
# - optional, default: "proxy"
# - valid values:
# - "proxy": logs generated by llama-swap when swapping models,
# handling requests, etc.
# - "upstream": a copy of an upstream processes stdout logs
# - "both": both the proxy and upstream logs interleaved together
# - "none": no logs are ever written to stdout
logToStdout: "proxy"
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
@@ -126,6 +147,30 @@ metricsMaxInMemory: 1000
# - it is automatically incremented for every model that uses it
startPort: 10001
# sendLoadingState: inject loading status updates into the reasoning (thinking)
# field
# - optional, default: false
# - when true, a stream of loading messages will be sent to the client in the
# reasoning field so chat UIs can show that loading is in progress.
# - see #366 for more details
sendLoadingState: true
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
# - optional, default: false
# - when true, model aliases will be output to the API model listing duplicating
# all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
# macros: a dictionary of string substitutions
# - optional, default: empty dictionary
# - macros are reusable snippets
@@ -274,6 +319,33 @@ models:
# - recommended to be omitted and the default used
concurrencyLimit: 0
# timeouts: configure proxy connection timeouts for this model
# - optional, defaults shown below
# - useful for models on slower hardware that need longer timeouts
# - increase responseHeader to avoid "timeout awaiting response headers" errors
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
# connect: TCP connection timeout in seconds
# - default: 30
connect: 30
# responseHeader: time to wait for response headers in seconds
# - default: 60
# - for slow image generation or large models, consider increasing to 300+ seconds
responseHeader: 60
# tlsHandshake: TLS handshake timeout in seconds
# - default: 10
tlsHandshake: 10
# idleConn: idle connection timeout in seconds
# - default: 90
idleConn: 90
# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
# Unlisted model example:
"qwen-unlisted":
# unlisted: boolean, true or false
@@ -383,4 +455,47 @@ hooks:
# otherwise models will be loaded and swapped out
preload:
- "llama"
# peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary
# - peers can be another llama-swap
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
peers:
# keys is the peer'd ID
llama-swap-peer:
# proxy: a valid base URL to proxy requests to
# - required
# - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23
# timeouts: configure proxy connection timeouts for this peer
# - optional, defaults shown below
# - useful when the peer runs on slower hardware
# - set any value to 0 to disable that timeout (not recommended)
timeouts:
connect: 30
responseHeader: 60
tlsHandshake: 10
idleConn: 90
# models: a list of models served by the peer
# - required
models:
- model_a
- model_b
- embeddings/model_c
openrouter:
proxy: https://openrouter.ai/api
# apiKey: a string key to be injected into the request
# - optional, default: ""
# - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
apiKey: sk-your-openrouter-key
models:
- meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507
- deepseek/deepseek-v3.2
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
```
+9
View File
@@ -0,0 +1,9 @@
## Container Security
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
+5 -5
View File
@@ -1,6 +1,6 @@
module github.com/mostlygeek/llama-swap
go 1.23.0
go 1.26.1
require (
github.com/billziss-gh/golib v0.2.0
@@ -37,9 +37,9 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)
+8 -8
View File
@@ -80,16 +80,16 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
+247 -62
View File
@@ -15,6 +15,12 @@ import (
)
const DEFAULT_GROUP_ID = "(default)"
const (
LogToStdoutProxy = "proxy"
LogToStdoutUpstream = "upstream"
LogToStdoutBoth = "both"
LogToStdoutNone = "none"
)
type MacroEntry struct {
Name string
@@ -81,6 +87,7 @@ type GroupConfig struct {
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
)
// set default values for GroupConfig
@@ -114,7 +121,10 @@ type Config struct {
LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"`
LogTimeFormat string `yaml:"logTimeFormat"`
LogToStdout string `yaml:"logToStdout"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
CaptureBuffer int `yaml:"captureBuffer"`
GlobalTTL int `yaml:"globalTTL"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
@@ -136,6 +146,12 @@ type Config struct {
// present aliases to /v1/models OpenAI API listing
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
// support API keys, see issue #433, #50, #251
RequiredAPIKeys []string `yaml:"apiKeys"`
// support remote peers, see issue #433, #296
Peers PeerDictionaryConfig `yaml:"peers"`
}
func (c *Config) RealModelName(search string) (string, bool) {
@@ -170,22 +186,31 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
if err != nil {
return Config{}, err
}
yamlStr := string(data)
// default configuration values
// Phase 1: Substitute all ${env.VAR} macros at string level
// This is safe because env values are simple strings without YAML formatting
yamlStr, err = substituteEnvMacros(yamlStr)
if err != nil {
return Config{}, err
}
// Unmarshal into full Config with defaults
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
GlobalTTL: 0,
}
err = yaml.Unmarshal(data, &config)
if err != nil {
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
// set a minimum of 15 seconds
config.HealthCheckTimeout = 15
}
@@ -193,6 +218,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
if config.GlobalTTL < 0 {
return Config{}, fmt.Errorf("globalTTL must be >= 0")
}
switch config.LogToStdout {
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
default:
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
@@ -204,55 +239,55 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
/* check macro constraint rules:
- name must fit the regex ^[a-zA-Z0-9_-]+$
- names must be less than 64 characters (no reason, just cause)
- name can not be any reserved macros: PORT, MODEL_ID
- macro values must be less than 1024 characters
*/
// Validate global macros
for _, macro := range config.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, err
}
}
// Get and sort all model IDs first, makes testing more consistent
// Get and sort all model IDs for consistent port assignment
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds) // This guarantees stable iteration order
sort.Strings(modelIds)
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
// Strip comments from command fields before macro expansion
// Strip comments from command fields
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// validate model macros
// set model TTL to globalTTL it is the default value
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
modelConfig.UnloadAfter = config.GlobalTTL
}
if modelConfig.UnloadAfter < 0 {
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
}
// Validate model macros
for _, macro := range modelConfig.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
}
}
// Merge global config and model macros. Model macros take precedence
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
// Add global macros first
mergedMacros = append(mergedMacros, config.Macros...)
// Add model macros (can override global)
// Add model macros (override globals with same name)
for _, entry := range modelConfig.Macros {
// Remove any existing global macro with same name
found := false
for i, existing := range mergedMacros {
if existing.Name == entry.Name {
mergedMacros[i] = entry // Override
mergedMacros[i] = entry
found = true
break
}
@@ -262,23 +297,40 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
// This allows later macros to reference earlier ones
// Substitute remaining macros in model fields (LIFO order)
for i := len(mergedMacros) - 1; i >= 0; i-- {
entry := mergedMacros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
// Substitute in command fields
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
// Substitute in metadata (recursive)
// Substitute macros in SetParamsByID keys and values
if len(modelConfig.Filters.SetParamsByID) > 0 {
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
for key, paramMap := range modelConfig.Filters.SetParamsByID {
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
}
newParamMap, ok := newValAny.(map[string]any)
if !ok {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
}
newSetParamsByID[newKey] = newParamMap
}
modelConfig.Filters.SetParamsByID = newSetParamsByID
}
// Substitute in metadata (type-preserving)
if len(modelConfig.Metadata) > 0 {
var err error
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
@@ -287,29 +339,25 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
// Final pass: check if PORT macro is needed after macro expansion
// ${PORT} is a resource on the local machine so a new port is only allocated
// if it is required in either cmd or proxy keys
// Handle PORT macro - only allocate if cmd uses it
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
if cmdHasPort || proxyHasPort { // either has it
if !cmdHasPort && proxyHasPort { // but both don't have it
if cmdHasPort || proxyHasPort {
if !cmdHasPort && proxyHasPort {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// Add PORT macro and substitute it
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
macroSlug := "${PORT}"
macroStr := fmt.Sprintf("%v", nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
// Substitute PORT in metadata
if len(modelConfig.Metadata) > 0 {
var err error
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
@@ -319,13 +367,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
nextPort++
}
// make sure there are no unknown macros that have not been replaced
// Validate no unknown macros remain
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
"filters.stripParams": modelConfig.Filters.StripParams,
"name": modelConfig.Name,
"description": modelConfig.Description,
}
for fieldName, fieldValue := range fieldMap {
@@ -333,35 +383,55 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
continue // this is ok, has to be replaced by process later
continue // replaced at runtime
}
// Reserved macros are always valid (they should have been substituted already)
if macroName == "PORT" || macroName == "MODEL_ID" {
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
}
// Any other macro is unknown
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
// Check for unknown macros in metadata
if len(modelConfig.Metadata) > 0 {
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
return Config{}, err
}
}
// Validate the proxy URL.
if _, err := url.Parse(modelConfig.Proxy); err != nil {
return Config{}, fmt.Errorf(
"model %s: invalid proxy URL: %w", modelId, err,
)
// Validate SetParamsByID keys and values
for key, paramMap := range modelConfig.Filters.SetParamsByID {
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
}
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
return Config{}, err
}
}
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
for key := range modelConfig.Filters.SetParamsByID {
if key == modelId {
continue
}
if _, exists := config.Models[key]; exists {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
}
if existingModel, exists := config.aliases[key]; exists {
if existingModel != modelId {
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
}
continue // already registered as explicit alias for this model
}
config.aliases[key] = modelId
modelConfig.Aliases = append(modelConfig.Aliases, key)
}
if _, err := url.Parse(modelConfig.Proxy); err != nil {
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
}
// if sendLoadingState is nil, set it to the global config value
// see #366
if modelConfig.SendLoadingState == nil {
v := config.SendLoadingState // copy it
v := config.SendLoadingState
modelConfig.SendLoadingState = &v
}
@@ -369,18 +439,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in
// Validate group members
memberUsage := make(map[string]string)
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
// Check for duplicates within this group
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
// Check if member is used in another group
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
@@ -388,7 +457,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
// clean up hooks preload
// Clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string
for _, modelID := range config.Hooks.OnStartup.Preload {
@@ -400,10 +469,56 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
toPreload = append(toPreload, real)
}
}
config.Hooks.OnStartup.Preload = toPreload
}
// Validate API keys (env macros already substituted at string level)
for i, apikey := range config.RequiredAPIKeys {
if apikey == "" {
return Config{}, fmt.Errorf("empty api key found in apiKeys")
}
if strings.Contains(apikey, " ") {
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
}
config.RequiredAPIKeys[i] = apikey
}
// Process peers with global macro substitution
for peerName, peerConfig := range config.Peers {
// Substitute global macros (LIFO order)
for i := len(config.Macros) - 1; i >= 0; i-- {
entry := config.Macros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
// Substitute in setParams (type-preserving)
if len(peerConfig.Filters.SetParams) > 0 {
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
}
peerConfig.Filters.SetParams = result.(map[string]any)
}
}
// Validate no unknown macros remain
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
}
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
}
if len(peerConfig.Filters.SetParams) > 0 {
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
return Config{}, err
}
}
config.Peers[peerName] = peerConfig
}
return config, nil
}
@@ -534,20 +649,26 @@ func validateMacro(name string, value any) error {
return nil
}
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
func validateMetadataForUnknownMacros(value any, modelId string) error {
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
func validateNestedForUnknownMacros(value any, context string) error {
switch v := value.(type) {
case string:
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
for _, match := range matches {
macroName := match[1]
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
}
// Check for unsubstituted env macros
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
for _, match := range envMatches {
varName := match[1]
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
}
return nil
case map[string]any:
for _, val := range v {
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
if err := validateNestedForUnknownMacros(val, context); err != nil {
return err
}
}
@@ -555,7 +676,7 @@ func validateMetadataForUnknownMacros(value any, modelId string) error {
case []any:
for _, val := range v {
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
if err := validateNestedForUnknownMacros(val, context); err != nil {
return err
}
}
@@ -614,3 +735,67 @@ func substituteMacroInValue(value any, macroName string, macroValue any) (any, e
return value, nil
}
}
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
// Returns error if any referenced env var is not set or contains invalid characters.
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
// (which strips comments) and only checking the comment-free version for macros.
func substituteEnvMacros(s string) (string, error) {
// Unmarshal and remarshal to strip YAML comments
var raw any
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
// If YAML is invalid, fall back to scanning the original string
// so the user gets the env var error rather than a confusing YAML parse error
return substituteEnvMacrosInString(s, s)
}
clean, err := yaml.Marshal(raw)
if err != nil {
return substituteEnvMacrosInString(s, s)
}
return substituteEnvMacrosInString(s, string(clean))
}
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
// them in target. This separation allows scanning comment-free YAML while
// substituting in the original string.
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
result := target
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
for _, match := range matches {
fullMatch := match[0] // ${env.VAR_NAME}
varName := match[1] // VAR_NAME
value, exists := os.LookupEnv(varName)
if !exists {
return "", fmt.Errorf("environment variable '%s' is not set", varName)
}
// Sanitize the value for safe YAML substitution
value, err := sanitizeEnvValueForYAML(value, varName)
if err != nil {
return "", err
}
result = strings.ReplaceAll(result, fullMatch, value)
}
return result, nil
}
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
// for compatibility with double-quoted YAML strings.
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
// Reject values that would break YAML structure regardless of quoting context
if strings.ContainsAny(value, "\n\r\x00") {
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
}
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
// In double-quoted contexts, they are interpreted correctly.
value = strings.ReplaceAll(value, `\`, `\\`)
value = strings.ReplaceAll(value, `"`, `\"`)
return value, nil
}
+15
View File
@@ -163,9 +163,19 @@ groups:
modelLoadingState := false
defaultTimeout := TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
}
expected := Config{
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
StartPort: 5800,
Macros: MacroList{
{"svr-path", "path/to/server"},
@@ -186,6 +196,7 @@ groups:
Name: "Model 1",
Description: "This is model 1",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model2": {
Cmd: "path/to/server --arg1 one",
@@ -194,6 +205,7 @@ groups:
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
@@ -202,6 +214,7 @@ groups:
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
@@ -210,10 +223,12 @@ groups:
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
},
HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
+783
View File
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConfig_GroupMemberIsUnique(t *testing.T) {
@@ -761,3 +762,785 @@ models:
})
}
}
func TestConfig_APIKeys_Invalid(t *testing.T) {
tests := []struct {
name string
content string
expectedErr string
}{
{
name: "empty string",
content: `apiKeys: [""]`,
expectedErr: "empty api key found in apiKeys",
},
{
name: "blank spaces only",
content: `apiKeys: [" "]`,
expectedErr: "api key cannot contain spaces: ` `",
},
{
name: "contains leading space",
content: `apiKeys: [" key123"]`,
expectedErr: "api key cannot contain spaces: ` key123`",
},
{
name: "contains trailing space",
content: `apiKeys: ["key123 "]`,
expectedErr: "api key cannot contain spaces: `key123 `",
},
{
name: "contains middle space",
content: `apiKeys: ["key 123"]`,
expectedErr: "api key cannot contain spaces: `key 123`",
},
{
name: "empty in list with valid keys",
content: `apiKeys: ["valid-key", "", "another-key"]`,
expectedErr: "empty api key found in apiKeys",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
if assert.Error(t, err) {
assert.Equal(t, tt.expectedErr, err.Error())
}
})
}
}
func TestConfig_APIKeys_EnvMacros(t *testing.T) {
t.Run("env substitution in apiKeys", func(t *testing.T) {
t.Setenv("TEST_API_KEY", "secret-key-123")
content := `apiKeys: ["${env.TEST_API_KEY}"]`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, []string{"secret-key-123"}, config.RequiredAPIKeys)
})
t.Run("multiple env substitutions in apiKeys", func(t *testing.T) {
t.Setenv("TEST_API_KEY_1", "key-one")
t.Setenv("TEST_API_KEY_2", "key-two")
content := `apiKeys: ["${env.TEST_API_KEY_1}", "${env.TEST_API_KEY_2}", "static-key"]`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, []string{"key-one", "key-two", "static-key"}, config.RequiredAPIKeys)
})
t.Run("missing env var in apiKeys", func(t *testing.T) {
content := `apiKeys: ["${env.NONEXISTENT_API_KEY}"]`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
// With string-level env substitution, error only includes var name
assert.Contains(t, err.Error(), "NONEXISTENT_API_KEY")
})
t.Run("env substitution results in empty key", func(t *testing.T) {
t.Setenv("TEST_EMPTY_KEY", "")
content := `apiKeys: ["${env.TEST_EMPTY_KEY}"]`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Equal(t, "empty api key found in apiKeys", err.Error())
})
}
func TestConfig_GlobalTTL(t *testing.T) {
t.Run("globalTTL sets default for models", func(t *testing.T) {
content := `
globalTTL: 300
models:
model1:
cmd: server --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 300, config.GlobalTTL)
assert.Equal(t, 300, config.Models["model1"].UnloadAfter)
})
t.Run("model ttl=0 overrides globalTTL", func(t *testing.T) {
content := `
globalTTL: 300
models:
model1:
cmd: server --port ${PORT}
ttl: 0
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
})
t.Run("model explicit ttl overrides globalTTL", func(t *testing.T) {
content := `
globalTTL: 300
models:
model1:
cmd: server --port ${PORT}
ttl: 600
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 600, config.Models["model1"].UnloadAfter)
})
t.Run("globalTTL defaults to 0", func(t *testing.T) {
content := `
models:
model1:
cmd: server --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 0, config.GlobalTTL)
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
})
t.Run("negative globalTTL rejected", func(t *testing.T) {
content := `
globalTTL: -1
models:
model1:
cmd: server --port ${PORT}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "globalTTL must be >= 0")
})
}
func TestConfig_EnvMacros(t *testing.T) {
t.Run("basic env substitution in cmd", func(t *testing.T) {
t.Setenv("TEST_MODEL_PATH", "/opt/models")
content := `
models:
test:
cmd: "${env.TEST_MODEL_PATH}/llama-server"
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "/opt/models/llama-server", config.Models["test"].Cmd)
})
t.Run("env substitution in multiple fields", func(t *testing.T) {
t.Setenv("TEST_HOST", "myserver")
t.Setenv("TEST_PORT", "9999")
content := `
models:
test:
cmd: "server --host ${env.TEST_HOST}"
proxy: "http://${env.TEST_HOST}:${env.TEST_PORT}"
checkEndpoint: "http://${env.TEST_HOST}/health"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "server --host myserver", config.Models["test"].Cmd)
assert.Equal(t, "http://myserver:9999", config.Models["test"].Proxy)
assert.Equal(t, "http://myserver/health", config.Models["test"].CheckEndpoint)
})
t.Run("env in global macro value", func(t *testing.T) {
t.Setenv("TEST_BASE_PATH", "/usr/local")
content := `
macros:
SERVER_PATH: "${env.TEST_BASE_PATH}/bin/server"
models:
test:
cmd: "${SERVER_PATH} --port 8080"
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "/usr/local/bin/server --port 8080", config.Models["test"].Cmd)
})
t.Run("env in model-level macro value", func(t *testing.T) {
t.Setenv("TEST_MODEL_DIR", "/models/llama")
content := `
models:
test:
macros:
MODEL_FILE: "${env.TEST_MODEL_DIR}/model.gguf"
cmd: "server --model ${MODEL_FILE}"
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "server --model /models/llama/model.gguf", config.Models["test"].Cmd)
})
t.Run("env in metadata", func(t *testing.T) {
t.Setenv("TEST_API_KEY", "secret123")
content := `
models:
test:
cmd: "server"
proxy: "http://localhost:8080"
metadata:
api_key: "${env.TEST_API_KEY}"
nested:
key: "${env.TEST_API_KEY}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "secret123", config.Models["test"].Metadata["api_key"])
nested := config.Models["test"].Metadata["nested"].(map[string]any)
assert.Equal(t, "secret123", nested["key"])
})
t.Run("env in filters.stripParams", func(t *testing.T) {
t.Setenv("TEST_STRIP_PARAMS", "temperature,top_p")
content := `
models:
test:
cmd: "server"
proxy: "http://localhost:8080"
filters:
stripParams: "${env.TEST_STRIP_PARAMS}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "temperature,top_p", config.Models["test"].Filters.StripParams)
})
t.Run("env in cmdStop", func(t *testing.T) {
t.Setenv("TEST_KILL_SIGNAL", "SIGTERM")
content := `
models:
test:
cmd: "server --port ${PORT}"
cmdStop: "kill -${env.TEST_KILL_SIGNAL} ${PID}"
proxy: "http://localhost:${PORT}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Contains(t, config.Models["test"].CmdStop, "-SIGTERM")
})
t.Run("missing env var returns error", func(t *testing.T) {
content := `
models:
test:
cmd: "${env.UNDEFINED_VAR_12345}/server"
proxy: "http://localhost:8080"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "UNDEFINED_VAR_12345")
assert.Contains(t, err.Error(), "not set")
}
})
t.Run("missing env var in global macro", func(t *testing.T) {
content := `
macros:
PATH: "${env.UNDEFINED_GLOBAL_VAR}"
models:
test:
cmd: "server"
proxy: "http://localhost:8080"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "UNDEFINED_GLOBAL_VAR")
assert.Contains(t, err.Error(), "not set")
}
})
t.Run("missing env var in model macro", func(t *testing.T) {
content := `
models:
test:
macros:
MY_PATH: "${env.UNDEFINED_MODEL_VAR}"
cmd: "server"
proxy: "http://localhost:8080"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "UNDEFINED_MODEL_VAR")
assert.Contains(t, err.Error(), "not set")
}
})
t.Run("missing env var in metadata", func(t *testing.T) {
content := `
models:
test:
cmd: "server"
proxy: "http://localhost:8080"
metadata:
key: "${env.UNDEFINED_META_VAR}"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "UNDEFINED_META_VAR")
assert.Contains(t, err.Error(), "not set")
}
})
t.Run("env combined with regular macros", func(t *testing.T) {
t.Setenv("TEST_ROOT", "/data")
content := `
macros:
MODEL_BASE: "${env.TEST_ROOT}/models"
models:
test:
cmd: "server --model ${MODEL_BASE}/${MODEL_ID}.gguf"
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "server --model /data/models/test.gguf", config.Models["test"].Cmd)
})
t.Run("multiple env vars in same string", func(t *testing.T) {
t.Setenv("TEST_USER", "admin")
t.Setenv("TEST_PASS", "secret")
content := `
models:
test:
cmd: "server --auth ${env.TEST_USER}:${env.TEST_PASS}"
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "server --auth admin:secret", config.Models["test"].Cmd)
})
t.Run("env value with newline is rejected", func(t *testing.T) {
t.Setenv("TEST_MULTILINE", "line1\nline2")
content := `
models:
test:
cmd: "server --config ${env.TEST_MULTILINE}"
proxy: "http://localhost:8080"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "TEST_MULTILINE")
assert.Contains(t, err.Error(), "newlines")
}
})
t.Run("env value with carriage return is rejected", func(t *testing.T) {
t.Setenv("TEST_CR", "line1\rline2")
content := `
models:
test:
cmd: "server --config ${env.TEST_CR}"
proxy: "http://localhost:8080"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "TEST_CR")
assert.Contains(t, err.Error(), "newlines")
}
})
t.Run("env value with quotes is escaped for YAML", func(t *testing.T) {
t.Setenv("TEST_QUOTED", `value with "quotes"`)
content := `
models:
test:
cmd: "server --arg \"${env.TEST_QUOTED}\""
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// Quotes are escaped before YAML parsing, then YAML unescapes them
// Final result preserves the original value with quotes
assert.Contains(t, config.Models["test"].Cmd, `"quotes"`)
})
t.Run("env value with backslash is escaped for YAML", func(t *testing.T) {
t.Setenv("TEST_BACKSLASH", `path\to\file`)
content := `
models:
test:
cmd: "server --path \"${env.TEST_BACKSLASH}\""
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// Backslashes are escaped before YAML parsing, then YAML unescapes them
// Final result preserves the original value with backslashes
assert.Contains(t, config.Models["test"].Cmd, `path\to\file`)
})
}
func TestConfig_PeerApiKey_EnvMacros(t *testing.T) {
t.Run("env substitution in peer apiKey", func(t *testing.T) {
t.Setenv("TEST_PEER_API_KEY", "sk-peer-secret-123")
content := `
peers:
openrouter:
proxy: https://openrouter.ai/api
apiKey: "${env.TEST_PEER_API_KEY}"
models:
- llama-3.1-8b
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "sk-peer-secret-123", config.Peers["openrouter"].ApiKey)
})
t.Run("missing env var in peer apiKey", func(t *testing.T) {
content := `
peers:
openrouter:
proxy: https://openrouter.ai/api
apiKey: "${env.NONEXISTENT_PEER_KEY}"
models:
- llama-3.1-8b
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
// With string-level env substitution, error only includes var name
assert.Contains(t, err.Error(), "NONEXISTENT_PEER_KEY")
})
t.Run("static apiKey unchanged", func(t *testing.T) {
content := `
peers:
openrouter:
proxy: https://openrouter.ai/api
apiKey: sk-static-key
models:
- llama-3.1-8b
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "sk-static-key", config.Peers["openrouter"].ApiKey)
})
t.Run("multiple peers with env apiKeys", func(t *testing.T) {
t.Setenv("TEST_PEER_KEY_1", "key-one")
t.Setenv("TEST_PEER_KEY_2", "key-two")
content := `
peers:
peer1:
proxy: https://peer1.example.com
apiKey: "${env.TEST_PEER_KEY_1}"
models:
- model-a
peer2:
proxy: https://peer2.example.com
apiKey: "${env.TEST_PEER_KEY_2}"
models:
- model-b
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "key-one", config.Peers["peer1"].ApiKey)
assert.Equal(t, "key-two", config.Peers["peer2"].ApiKey)
})
t.Run("global macro substitution in peer apiKey", func(t *testing.T) {
content := `
macros:
API_KEY: sk-from-global-macro
peers:
openrouter:
proxy: https://openrouter.ai/api
apiKey: "${API_KEY}"
models:
- llama-3.1-8b
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "sk-from-global-macro", config.Peers["openrouter"].ApiKey)
})
t.Run("global macro in peer filters.stripParams", func(t *testing.T) {
content := `
macros:
STRIP_LIST: "temperature, top_p"
peers:
openrouter:
proxy: https://openrouter.ai/api
models:
- llama-3.1-8b
filters:
stripParams: "${STRIP_LIST}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "temperature, top_p", config.Peers["openrouter"].Filters.StripParams)
})
t.Run("global macro in peer filters.setParams", func(t *testing.T) {
content := `
macros:
MAX_TOKENS: 4096
peers:
openrouter:
proxy: https://openrouter.ai/api
models:
- llama-3.1-8b
filters:
setParams:
max_tokens: "${MAX_TOKENS}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, 4096, config.Peers["openrouter"].Filters.SetParams["max_tokens"])
})
t.Run("env macro in peer filters.setParams", func(t *testing.T) {
t.Setenv("TEST_RETENTION_POLICY", "deny")
content := `
peers:
openrouter:
proxy: https://openrouter.ai/api
models:
- llama-3.1-8b
filters:
setParams:
data_collection: "${env.TEST_RETENTION_POLICY}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "deny", config.Peers["openrouter"].Filters.SetParams["data_collection"])
})
t.Run("env macro in peer filters.stripParams", func(t *testing.T) {
t.Setenv("TEST_STRIP_PARAMS", "frequency_penalty, presence_penalty")
content := `
peers:
openrouter:
proxy: https://openrouter.ai/api
models:
- llama-3.1-8b
filters:
stripParams: "${env.TEST_STRIP_PARAMS}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "frequency_penalty, presence_penalty", config.Peers["openrouter"].Filters.StripParams)
})
t.Run("unknown macro in peer apiKey fails", func(t *testing.T) {
content := `
peers:
openrouter:
proxy: https://openrouter.ai/api
apiKey: "${UNDEFINED_MACRO}"
models:
- llama-3.1-8b
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "peers.openrouter.apiKey")
assert.Contains(t, err.Error(), "unknown macro")
})
t.Run("unknown macro in peer filters.setParams fails", func(t *testing.T) {
content := `
peers:
openrouter:
proxy: https://openrouter.ai/api
models:
- llama-3.1-8b
filters:
setParams:
value: "${UNDEFINED_MACRO}"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "peers.openrouter.filters.setParams")
assert.Contains(t, err.Error(), "unknown macro")
})
t.Run("env macros in comments are ignored", func(t *testing.T) {
content := `
# apiKeys:
# - "${env.COMMENTED_OUT_KEY_1}"
# - "${env.COMMENTED_OUT_KEY_2}"
models:
test:
cmd: "server"
proxy: "http://localhost:8080"
`
// These env vars are NOT set, but should not cause an error
// because they only appear in comment lines
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Empty(t, config.RequiredAPIKeys)
})
t.Run("env macros in comments ignored while active ones resolve", func(t *testing.T) {
t.Setenv("TEST_ACTIVE_KEY", "active-key-value")
content := `
# apiKeys: ["${env.COMMENTED_OUT_KEY}"]
apiKeys: ["${env.TEST_ACTIVE_KEY}"]
models:
test:
cmd: "server"
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, []string{"active-key-value"}, config.RequiredAPIKeys)
})
t.Run("env macros in indented comments are ignored", func(t *testing.T) {
content := `
models:
test:
cmd: |
server
--port 8080
proxy: "http://localhost:8080"
# metadata:
# api_key: "${env.SOME_UNSET_KEY}"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
})
t.Run("env macros in inline comments are ignored", func(t *testing.T) {
t.Setenv("TEST_INLINE_KEY", "real-value")
content := `
apiKeys: ["${env.TEST_INLINE_KEY}"] # TODO: add ${env.FUTURE_KEY} later
models:
test:
cmd: "server"
proxy: "http://localhost:8080"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, []string{"real-value"}, config.RequiredAPIKeys)
})
}
func TestConfig_TimeoutsParsing(t *testing.T) {
configYaml := `
models:
model1:
cmd: test-server --port ${PORT}
timeouts:
connect: 45
responseHeader: 120
`
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
require.NoError(t, err)
modelConfig, found := config.Models["model1"]
require.True(t, found, "model1 should exist in config")
assert.Equal(t, 45, modelConfig.Timeouts.Connect)
assert.Equal(t, 120, modelConfig.Timeouts.ResponseHeader)
}
func TestConfig_TimeoutsDefaults(t *testing.T) {
configYaml := `
models:
model1:
cmd: test-server --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
require.NoError(t, err)
modelConfig, found := config.Models["model1"]
require.True(t, found, "model1 should exist in config")
// Default values should be set during unmarshaling
assert.Equal(t, 30, modelConfig.Timeouts.Connect)
assert.Equal(t, 0, modelConfig.Timeouts.ResponseHeader)
assert.Equal(t, 10, modelConfig.Timeouts.TLSHandshake)
assert.Equal(t, 1, modelConfig.Timeouts.ExpectContinue)
assert.Equal(t, 90, modelConfig.Timeouts.IdleConn)
}
func TestConfig_TimeoutsZeroAllowed(t *testing.T) {
configYaml := `
models:
model1:
cmd: test-server --port ${PORT}
timeouts:
connect: 0
responseHeader: 0
`
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
require.NoError(t, err)
modelConfig, found := config.Models["model1"]
require.True(t, found, "model1 should exist in config")
// Explicit 0 should be preserved (disables timeout)
assert.Equal(t, 0, modelConfig.Timeouts.Connect)
assert.Equal(t, 0, modelConfig.Timeouts.ResponseHeader)
}
func TestConfig_PeerTimeoutsParsing(t *testing.T) {
configYaml := `
peers:
peer1:
proxy: http://example.com
models: [model1]
timeouts:
connect: 45
responseHeader: 120
`
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
require.NoError(t, err)
peerConfig, found := config.Peers["peer1"]
require.True(t, found, "peer1 should exist in config")
assert.Equal(t, 45, peerConfig.Timeouts.Connect)
assert.Equal(t, 120, peerConfig.Timeouts.ResponseHeader)
}
func TestConfig_PeerTimeoutsDefaults(t *testing.T) {
configYaml := `
peers:
peer1:
proxy: http://example.com
models: [model1]
`
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
require.NoError(t, err)
peerConfig, found := config.Peers["peer1"]
require.True(t, found, "peer1 should exist in config")
// Default values should be set during unmarshaling
assert.Equal(t, 30, peerConfig.Timeouts.Connect)
assert.Equal(t, 60, peerConfig.Timeouts.ResponseHeader)
assert.Equal(t, 10, peerConfig.Timeouts.TLSHandshake)
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
}
+15
View File
@@ -155,9 +155,19 @@ groups:
modelLoadingState := false
defaultTimeout := TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
}
expected := Config{
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
StartPort: 5800,
Macros: MacroList{
{"svr-path", "path/to/server"},
@@ -172,6 +182,7 @@ groups:
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model2": {
Cmd: "path/to/server --arg1 one",
@@ -181,6 +192,7 @@ groups:
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model3": {
Cmd: "path/to/cmd --arg1 one",
@@ -190,6 +202,7 @@ groups:
Env: []string{},
CheckEndpoint: "/",
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
"model4": {
Cmd: "path/to/cmd --arg1 one",
@@ -199,10 +212,12 @@ groups:
Aliases: []string{},
Env: []string{},
SendLoadingState: &modelLoadingState,
Timeouts: defaultTimeout,
},
},
HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
+114
View File
@@ -0,0 +1,114 @@
package config
import (
"slices"
"sort"
"strings"
)
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
// These are protected to prevent breaking the proxy's ability to route requests correctly
var ProtectedParams = []string{"model"}
// Filters contains filter settings for modifying request parameters
// Used by both models and peers
type Filters struct {
// StripParams is a comma-separated list of parameters to remove from requests
// The "model" parameter can never be removed
StripParams string `yaml:"stripParams"`
// SetParams is a dictionary of parameters to set/override in requests
// Protected params (like "model") cannot be set
SetParams map[string]any `yaml:"setParams"`
// SetParamsByID maps requested model IDs to parameters to set/override in requests.
// Useful with aliases: a single loaded model can behave differently depending on
// which alias the client used. Applied after SetParams, so it can override those values.
// Protected params (like "model") cannot be set.
SetParamsByID map[string]map[string]any `yaml:"setParamsByID"`
}
// SanitizedStripParams returns a sorted list of parameters to strip,
// with duplicates, empty strings, and protected params removed
func (f Filters) SanitizedStripParams() []string {
if f.StripParams == "" {
return nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)
for _, param := range params {
trimmed := strings.TrimSpace(param)
// Skip protected params, empty strings, and duplicates
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}
if len(cleaned) == 0 {
return nil
}
slices.Sort(cleaned)
return cleaned
}
// SanitizedSetParamsByID returns the params to set for the given requestedModelID,
// with protected params removed and keys sorted for consistent iteration order.
// Returns nil if the ID has no entry or all its params are protected.
func (f Filters) SanitizedSetParamsByID(requestedModelID string) (map[string]any, []string) {
if len(f.SetParamsByID) == 0 {
return nil, nil
}
params, found := f.SetParamsByID[requestedModelID]
if !found || len(params) == 0 {
return nil, nil
}
result := make(map[string]any, len(params))
keys := make([]string, 0, len(params))
for key, value := range params {
if slices.Contains(ProtectedParams, key) {
continue
}
result[key] = value
keys = append(keys, key)
}
sort.Strings(keys)
if len(result) == 0 {
return nil, nil
}
return result, keys
}
// SanitizedSetParams returns a copy of SetParams with protected params removed
// and keys sorted for consistent iteration order
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
if len(f.SetParams) == 0 {
return nil, nil
}
result := make(map[string]any, len(f.SetParams))
keys := make([]string, 0, len(f.SetParams))
for key, value := range f.SetParams {
// Skip protected params
if slices.Contains(ProtectedParams, key) {
continue
}
result[key] = value
keys = append(keys, key)
}
// Sort keys for consistent ordering
sort.Strings(keys)
if len(result) == 0 {
return nil, nil
}
return result, keys
}
+285
View File
@@ -0,0 +1,285 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFilters_SanitizedStripParams(t *testing.T) {
tests := []struct {
name string
stripParams string
want []string
}{
{
name: "empty string",
stripParams: "",
want: nil,
},
{
name: "single param",
stripParams: "temperature",
want: []string{"temperature"},
},
{
name: "multiple params",
stripParams: "temperature, top_p, top_k",
want: []string{"temperature", "top_k", "top_p"}, // sorted
},
{
name: "model param filtered",
stripParams: "model, temperature, top_p",
want: []string{"temperature", "top_p"},
},
{
name: "only model param",
stripParams: "model",
want: nil,
},
{
name: "duplicates removed",
stripParams: "temperature, top_p, temperature",
want: []string{"temperature", "top_p"},
},
{
name: "extra whitespace",
stripParams: " temperature , top_p ",
want: []string{"temperature", "top_p"},
},
{
name: "empty values filtered",
stripParams: "temperature,,top_p,",
want: []string{"temperature", "top_p"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{StripParams: tt.stripParams}
got := f.SanitizedStripParams()
assert.Equal(t, tt.want, got)
})
}
}
func TestFilters_SanitizedSetParams(t *testing.T) {
tests := []struct {
name string
setParams map[string]any
wantParams map[string]any
wantKeys []string
}{
{
name: "empty setParams",
setParams: nil,
wantParams: nil,
wantKeys: nil,
},
{
name: "empty map",
setParams: map[string]any{},
wantParams: nil,
wantKeys: nil,
},
{
name: "normal params",
setParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantKeys: []string{"temperature", "top_p"},
},
{
name: "protected model param filtered",
setParams: map[string]any{
"model": "should-be-filtered",
"temperature": 0.7,
},
wantParams: map[string]any{
"temperature": 0.7,
},
wantKeys: []string{"temperature"},
},
{
name: "only protected param",
setParams: map[string]any{
"model": "should-be-filtered",
},
wantParams: nil,
wantKeys: nil,
},
{
name: "complex nested values",
setParams: map[string]any{
"provider": map[string]any{
"data_collection": "deny",
"allow_fallbacks": false,
},
"transforms": []string{"middle-out"},
},
wantParams: map[string]any{
"provider": map[string]any{
"data_collection": "deny",
"allow_fallbacks": false,
},
"transforms": []string{"middle-out"},
},
wantKeys: []string{"provider", "transforms"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{SetParams: tt.setParams}
gotParams, gotKeys := f.SanitizedSetParams()
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
for i, key := range gotKeys {
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
}
if tt.wantParams == nil {
assert.Nil(t, gotParams, "expected nil params")
return
}
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
for key, wantValue := range tt.wantParams {
gotValue, exists := gotParams[key]
assert.True(t, exists, "missing key: %s", key)
// Simple comparison for basic types
switch v := wantValue.(type) {
case string, int, float64, bool:
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
}
}
})
}
}
func TestFilters_SanitizedSetParamsByID(t *testing.T) {
tests := []struct {
name string
setParamsByID map[string]map[string]any
requestedModelID string
wantParams map[string]any
wantKeys []string
}{
{
name: "empty SetParamsByID returns nil",
setParamsByID: nil,
requestedModelID: "model1",
wantParams: nil,
wantKeys: nil,
},
{
name: "empty map returns nil",
setParamsByID: map[string]map[string]any{},
requestedModelID: "model1",
wantParams: nil,
wantKeys: nil,
},
{
name: "non-matching model ID returns nil",
setParamsByID: map[string]map[string]any{
"model2": {"temperature": 0.9},
},
requestedModelID: "model1",
wantParams: nil,
wantKeys: nil,
},
{
name: "matching model ID returns correct params",
setParamsByID: map[string]map[string]any{
"model1": {"temperature": 0.7, "top_p": 0.9},
"model2": {"temperature": 0.5},
},
requestedModelID: "model1",
wantParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantKeys: []string{"temperature", "top_p"},
},
{
name: "protected param model is filtered out",
setParamsByID: map[string]map[string]any{
"model1": {
"model": "should-be-filtered",
"temperature": 0.7,
},
},
requestedModelID: "model1",
wantParams: map[string]any{
"temperature": 0.7,
},
wantKeys: []string{"temperature"},
},
{
name: "only protected param returns nil",
setParamsByID: map[string]map[string]any{
"model1": {
"model": "should-be-filtered",
},
},
requestedModelID: "model1",
wantParams: nil,
wantKeys: nil,
},
{
name: "keys are sorted",
setParamsByID: map[string]map[string]any{
"model1": {
"z_param": "z",
"a_param": "a",
"m_param": "m",
},
},
requestedModelID: "model1",
wantParams: map[string]any{
"z_param": "z",
"a_param": "a",
"m_param": "m",
},
wantKeys: []string{"a_param", "m_param", "z_param"},
},
{
name: "alias style key lookup",
setParamsByID: map[string]map[string]any{
"model1:high": {"reasoning_effort": "high"},
"model1:low": {"reasoning_effort": "low"},
},
requestedModelID: "model1:high",
wantParams: map[string]any{
"reasoning_effort": "high",
},
wantKeys: []string{"reasoning_effort"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{SetParamsByID: tt.setParamsByID}
gotParams, gotKeys := f.SanitizedSetParamsByID(tt.requestedModelID)
if tt.wantParams == nil {
assert.Nil(t, gotParams)
assert.Nil(t, gotKeys)
return
}
assert.Equal(t, tt.wantKeys, gotKeys)
assert.Equal(t, tt.wantParams, gotParams)
})
}
}
func TestProtectedParams(t *testing.T) {
// Verify that "model" is protected
assert.Contains(t, ProtectedParams, "model")
}
+56
View File
@@ -104,6 +104,62 @@ models:
assert.Contains(t, err.Error(), "self-reference")
}
// Test macro substitution in name and description fields
func TestConfig_MacroInNameAndDescription(t *testing.T) {
content := `
startPort: 10000
macros:
"VARIANT": "Q4_K_M"
"FAMILY": "llama"
models:
my-model:
cmd: echo ok
proxy: http://localhost:8080
name: "${FAMILY} ${VARIANT}"
description: "A ${FAMILY} model in ${VARIANT} format"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "llama Q4_K_M", config.Models["my-model"].Name)
assert.Equal(t, "A llama model in Q4_K_M format", config.Models["my-model"].Description)
}
// Test MODEL_ID macro in name and description fields
func TestConfig_ModelIDInNameAndDescription(t *testing.T) {
content := `
startPort: 10000
models:
llama-3b:
cmd: echo ok
proxy: http://localhost:8080
name: "Model: ${MODEL_ID}"
description: "Running ${MODEL_ID}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "Model: llama-3b", config.Models["llama-3b"].Name)
assert.Equal(t, "Running llama-3b", config.Models["llama-3b"].Description)
}
// Test unknown macro in name or description returns an error
func TestConfig_UnknownMacroInNameDescription(t *testing.T) {
content := `
startPort: 10000
models:
test:
cmd: echo ok
proxy: http://localhost:8080
name: "Model ${UNDEFINED}"
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "UNDEFINED")
}
// Test undefined macro reference error
func TestConfig_UndefinedMacroReference(t *testing.T) {
content := `
+36 -28
View File
@@ -3,10 +3,23 @@ package config
import (
"errors"
"runtime"
"slices"
"strings"
)
const (
MODEL_CONFIG_DEFAULT_TTL = -1
)
// TimeoutsConfig holds timeout settings for proxy connections
// 0 = no timeout
type TimeoutsConfig struct {
Connect int `yaml:"connect"`
KeepAlive int `yaml:"keepalive"`
ResponseHeader int `yaml:"responseHeader"`
TLSHandshake int `yaml:"tlsHandshake"`
ExpectContinue int `yaml:"expectContinue"`
IdleConn int `yaml:"idleConn"`
}
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
@@ -38,6 +51,9 @@ type ModelConfig struct {
// override global setting
SendLoadingState *bool `yaml:"sendLoadingState"`
// Timeout settings for proxy connections
Timeouts TimeoutsConfig `yaml:"timeouts"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
@@ -49,12 +65,22 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
// matches http.DefaultTransport
Timeouts: TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 0,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}
// the default cmdStop to taskkill /f /t /pid ${PID}
@@ -74,16 +100,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
// ModelFilters embeds Filters and adds legacy support for strip_params field
// See issue #174
type ModelFilters struct {
StripParams string `yaml:"stripParams"`
Filters `yaml:",inline"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
defaults := rawModelFilters{}
if err := unmarshal(&defaults); err != nil {
return err
@@ -104,25 +129,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
return nil
}
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
// Returns ([]string, error) to match existing API
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
return f.Filters.SanitizedStripParams(), nil
}
+98
View File
@@ -72,3 +72,101 @@ models:
assert.True(t, *config.Models["model2"].SendLoadingState)
}
}
func TestConfig_SetParamsByIDAutoAlias(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
setParamsByID:
"${MODEL_ID}:high":
reasoning_effort: high
"${MODEL_ID}:low":
reasoning_effort: low
`
cfg, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// Keys (other than the model's own ID) should be registered as aliases
realName, found := cfg.RealModelName("model1:high")
assert.True(t, found, "model1:high should be an auto-registered alias")
assert.Equal(t, "model1", realName)
realName, found = cfg.RealModelName("model1:low")
assert.True(t, found, "model1:low should be an auto-registered alias")
assert.Equal(t, "model1", realName)
// Auto-aliases should also appear in modelConfig.Aliases
aliases := cfg.Models["model1"].Aliases
assert.Contains(t, aliases, "model1:high")
assert.Contains(t, aliases, "model1:low")
}
func TestConfig_SetParamsByIDAutoAliasConflictWithModelID(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
setParamsByID:
model2:
reasoning_effort: high
model2:
cmd: path/to/cmd --port ${PORT}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.ErrorContains(t, err, "conflicts with an existing model ID")
}
func TestConfig_SetParamsByIDAutoAliasConflictWithOtherModel(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
setParamsByID:
"shared-alias":
reasoning_effort: high
model2:
cmd: path/to/cmd --port ${PORT}
filters:
setParamsByID:
"shared-alias":
reasoning_effort: low
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.ErrorContains(t, err, "duplicate alias")
}
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
stripParams: "top_k"
setParams:
temperature: 0.7
top_p: 0.9
stop:
- "<|end|>"
- "<|stop|>"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
modelConfig := config.Models["model1"]
// Check stripParams
stripParams, err := modelConfig.Filters.SanitizedStripParams()
assert.NoError(t, err)
assert.Equal(t, []string{"top_k"}, stripParams)
// Check setParams
setParams, keys := modelConfig.Filters.SanitizedSetParams()
assert.NotNil(t, setParams)
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
assert.Equal(t, 0.7, setParams["temperature"])
assert.Equal(t, 0.9, setParams["top_p"])
}
+63
View File
@@ -0,0 +1,63 @@
package config
import (
"fmt"
"net/url"
)
type PeerDictionaryConfig map[string]PeerConfig
type PeerConfig struct {
Proxy string `yaml:"proxy"`
ProxyURL *url.URL `yaml:"-"`
ApiKey string `yaml:"apiKey"`
Models []string `yaml:"models"`
Filters Filters `yaml:"filters"`
// Timeout settings for proxy connections
Timeouts TimeoutsConfig `yaml:"timeouts"`
}
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawPeerConfig PeerConfig
defaults := rawPeerConfig{
Proxy: "",
ApiKey: "",
Models: []string{},
Filters: Filters{},
// mostly matches http.DefaultTransport but with a 60s ResponseHeader timeout
// to match the pre PR #619 functionality
Timeouts: TimeoutsConfig{
Connect: 30,
KeepAlive: 30,
ResponseHeader: 60,
TLSHandshake: 10,
ExpectContinue: 1,
IdleConn: 90,
},
}
if err := unmarshal(&defaults); err != nil {
return err
}
// Validate proxy is not empty
if defaults.Proxy == "" {
return fmt.Errorf("proxy is required")
}
// Validate proxy is a valid URL and store the parsed value
parsedURL, err := url.Parse(defaults.Proxy)
if err != nil {
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
}
defaults.ProxyURL = parsedURL
// Validate models is not empty
if len(defaults.Models) == 0 {
return fmt.Errorf("peer models can not be empty")
}
*c = PeerConfig(defaults)
return nil
}
+209
View File
@@ -0,0 +1,209 @@
package config
import (
"testing"
"gopkg.in/yaml.v3"
)
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
tests := []struct {
name string
yaml string
wantErr string
}{
{
name: "valid config",
yaml: `
proxy: http://192.168.1.23
models:
- model_a
- model_b
`,
wantErr: "",
},
{
name: "valid config with apiKey",
yaml: `
proxy: https://openrouter.ai/api
apiKey: sk-test-key
models:
- meta-llama/llama-3.1-8b-instruct
`,
wantErr: "",
},
{
name: "missing proxy",
yaml: `
models:
- model_a
`,
wantErr: "proxy is required",
},
{
name: "empty proxy",
yaml: `
proxy: ""
models:
- model_a
`,
wantErr: "proxy is required",
},
{
name: "invalid proxy URL",
yaml: `
proxy: "://invalid"
models:
- model_a
`,
wantErr: "invalid peer proxy URL",
},
{
name: "missing models",
yaml: `
proxy: http://localhost:8080
`,
wantErr: "peer models can not be empty",
},
{
name: "empty models",
yaml: `
proxy: http://localhost:8080
models: []
`,
wantErr: "peer models can not be empty",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var config PeerConfig
err := yaml.Unmarshal([]byte(tt.yaml), &config)
if tt.wantErr == "" {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
} else {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.wantErr)
} else if !contains(err.Error(), tt.wantErr) {
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
}
}
})
}
}
func TestPeerConfig_ProxyURL(t *testing.T) {
yamlData := `
proxy: http://192.168.1.23:8080/api
apiKey: sk-test
models:
- model_a
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if config.ProxyURL == nil {
t.Fatal("ProxyURL should not be nil")
}
if config.ProxyURL.Host != "192.168.1.23:8080" {
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
}
if config.ProxyURL.Scheme != "http" {
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
}
if config.ProxyURL.Path != "/api" {
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && searchSubstring(s, substr)
}
func searchSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func TestPeerConfig_WithFilters(t *testing.T) {
yamlData := `
proxy: https://openrouter.ai/api
apiKey: sk-test
models:
- model_a
filters:
setParams:
temperature: 0.7
provider:
data_collection: deny
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if config.Filters.SetParams == nil {
t.Fatal("Filters.SetParams should not be nil")
}
if config.Filters.SetParams["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
}
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
if !ok {
t.Fatal("provider should be a map")
}
if provider["data_collection"] != "deny" {
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
}
}
func TestPeerConfig_WithBothFilters(t *testing.T) {
yamlData := `
proxy: https://openrouter.ai/api
apiKey: sk-test
models:
- model_a
filters:
stripParams: "temperature, top_p"
setParams:
max_tokens: 1000
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check stripParams
stripParams := config.Filters.SanitizedStripParams()
if len(stripParams) != 2 {
t.Errorf("expected 2 strip params, got %d", len(stripParams))
}
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
t.Errorf("unexpected strip params: %v", stripParams)
}
// Check setParams
if config.Filters.SetParams == nil {
t.Fatal("Filters.SetParams should not be nil")
}
if config.Filters.SetParams["max_tokens"] != 1000 {
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
}
}
+9
View File
@@ -8,6 +8,7 @@ const ConfigFileChangedEventID = 0x03
const LogDataEventID = 0x04
const TokenMetricsEventID = 0x05
const ModelPreloadedEventID = 0x06
const InFlightRequestsEventID = 0x07
type ProcessStateChangeEvent struct {
ProcessName string
@@ -58,3 +59,11 @@ type ModelPreloadedEvent struct {
func (e ModelPreloadedEvent) Type() uint32 {
return ModelPreloadedEventID
}
type InFlightRequestsEvent struct {
Total int
}
func (e InFlightRequestsEvent) Type() uint32 {
return InFlightRequestsEventID
}
+5 -1
View File
@@ -71,11 +71,15 @@ func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
}
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
// Convert path to forward slashes for cross-platform compatibility
// Windows handles forward slashes in paths correctly
cmdPath := filepath.ToSlash(simpleResponderPath)
// Create a YAML string with just the values we want to set
yamlStr := fmt.Sprintf(`
cmd: '%s --port %d --silent --respond %s'
proxy: "http://127.0.0.1:%d"
`, simpleResponderPath, port, expectedMessage, port)
`, cmdPath, port, expectedMessage, port)
var cfg config.ModelConfig
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
+101 -16
View File
@@ -1,7 +1,6 @@
package proxy
import (
"container/ring"
"context"
"fmt"
"io"
@@ -12,6 +11,85 @@ import (
"github.com/mostlygeek/llama-swap/event"
)
// circularBuffer is a fixed-size circular byte buffer that overwrites
// oldest data when full. It provides O(1) writes and O(n) reads.
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
}
func newCircularBuffer(capacity int) *circularBuffer {
return &circularBuffer{
data: make([]byte, capacity),
head: 0,
size: 0,
}
}
// Write appends bytes to the buffer, overwriting oldest data when full.
// Data is copied into the internal buffer (not stored by reference).
func (cb *circularBuffer) Write(p []byte) {
if len(p) == 0 {
return
}
cap := len(cb.data)
// If input is larger than capacity, only keep the last cap bytes
if len(p) >= cap {
copy(cb.data, p[len(p)-cap:])
cb.head = 0
cb.size = cap
return
}
// Calculate how much space is available from head to end of buffer
firstPart := cap - cb.head
if firstPart >= len(p) {
// All data fits without wrapping
copy(cb.data[cb.head:], p)
cb.head = (cb.head + len(p)) % cap
} else {
// Data wraps around
copy(cb.data[cb.head:], p[:firstPart])
copy(cb.data[:len(p)-firstPart], p[firstPart:])
cb.head = len(p) - firstPart
}
// Update size
cb.size += len(p)
if cb.size > cap {
cb.size = cap
}
}
// GetHistory returns all buffered data in correct order (oldest to newest).
// Returns a new slice (copy), not a view into internal buffer.
func (cb *circularBuffer) GetHistory() []byte {
if cb.size == 0 {
return nil
}
result := make([]byte, cb.size)
cap := len(cb.data)
// Calculate start position (oldest data)
start := (cb.head - cb.size + cap) % cap
if start+cb.size <= cap {
// Data is contiguous, single copy
copy(result, cb.data[start:start+cb.size])
} else {
// Data wraps around, two copies
firstPart := cap - start
copy(result[:firstPart], cb.data[start:])
copy(result[firstPart:], cb.data[:cb.size-firstPart])
}
return result
}
type LogLevel int
const (
@@ -19,12 +97,14 @@ const (
LevelInfo
LevelWarn
LevelError
LogBufferSize = 100 * 1024
)
type LogMonitor struct {
eventbus *event.Dispatcher
mu sync.RWMutex
buffer *ring.Ring
buffer *circularBuffer
bufferMu sync.RWMutex
// typically this can be os.Stdout
@@ -45,7 +125,7 @@ func NewLogMonitor() *LogMonitor {
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{
eventbus: event.NewDispatcherConfig(1000),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
buffer: nil, // lazy initialized on first Write
stdout: stdout,
level: LevelInfo,
prefix: "",
@@ -64,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
}
w.bufferMu.Lock()
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.buffer.Value = bufferCopy
w.buffer = w.buffer.Next()
if w.buffer == nil {
w.buffer = newCircularBuffer(LogBufferSize)
}
w.buffer.Write(p)
w.bufferMu.Unlock()
// Make a copy for broadcast to preserve immutability
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.broadcast(bufferCopy)
return n, nil
}
@@ -77,16 +160,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
func (w *LogMonitor) GetHistory() []byte {
w.bufferMu.RLock()
defer w.bufferMu.RUnlock()
if w.buffer == nil {
return nil
}
return w.buffer.GetHistory()
}
var history []byte
w.buffer.Do(func(p any) {
if p != nil {
if content, ok := p.([]byte); ok {
history = append(history, content...)
}
}
})
return history
// Clear releases the buffer memory, making it eligible for GC.
// The buffer will be lazily re-allocated on the next Write.
func (w *LogMonitor) Clear() {
w.bufferMu.Lock()
w.buffer = nil
w.bufferMu.Unlock()
}
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
+201
View File
@@ -113,3 +113,204 @@ func TestWrite_LogTimeFormat(t *testing.T) {
t.Fatalf("Cannot find timestamp: %v", err)
}
}
func TestCircularBuffer_WrapAround(t *testing.T) {
// Create a small buffer to test wrap-around
cb := newCircularBuffer(10)
// Write "hello" (5 bytes)
cb.Write([]byte("hello"))
if got := string(cb.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Write "world" (5 bytes) - buffer now full
cb.Write([]byte("world"))
if got := string(cb.GetHistory()); got != "helloworld" {
t.Errorf("Expected 'helloworld', got %q", got)
}
// Write "12345" (5 bytes) - should overwrite "hello"
cb.Write([]byte("12345"))
if got := string(cb.GetHistory()); got != "world12345" {
t.Errorf("Expected 'world12345', got %q", got)
}
// Write data larger than buffer capacity
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
if got := string(cb.GetHistory()); got != "ghijklmnop" {
t.Errorf("Expected 'ghijklmnop', got %q", got)
}
}
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
// Test empty buffer
cb := newCircularBuffer(10)
if got := cb.GetHistory(); got != nil {
t.Errorf("Expected nil for empty buffer, got %q", got)
}
// Test exact capacity
cb.Write([]byte("1234567890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
// Test write exactly at capacity boundary
cb = newCircularBuffer(10)
cb.Write([]byte("12345"))
cb.Write([]byte("67890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
}
func TestLogMonitor_LazyInit(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Buffer should be nil before any writes
if lm.buffer != nil {
t.Error("Expected buffer to be nil before first write")
}
// GetHistory should return nil when buffer is nil
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history before first write, got %q", got)
}
// Write should lazily initialize the buffer
lm.Write([]byte("test"))
if lm.buffer == nil {
t.Error("Expected buffer to be initialized after write")
}
if got := string(lm.GetHistory()); got != "test" {
t.Errorf("Expected 'test', got %q", got)
}
}
func TestLogMonitor_Clear(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Write some data
lm.Write([]byte("hello"))
if got := string(lm.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Clear should release the buffer
lm.Clear()
if lm.buffer != nil {
t.Error("Expected buffer to be nil after Clear")
}
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history after Clear, got %q", got)
}
}
func TestLogMonitor_ClearAndReuse(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Write, clear, then write again
lm.Write([]byte("first"))
lm.Clear()
lm.Write([]byte("second"))
if got := string(lm.GetHistory()); got != "second" {
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
}
}
func BenchmarkLogMonitorWrite(b *testing.B) {
// Test data of varying sizes
smallMsg := []byte("small message\n")
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
b.Run("SmallWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(smallMsg)
}
})
b.Run("MediumWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
}
})
b.Run("LargeWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(largeMsg)
}
})
b.Run("WithSubscribers", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Add some subscribers
for i := 0; i < 5; i++ {
lm.OnLogData(func(data []byte) {})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
}
})
b.Run("GetHistory", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Pre-populate with data
for i := 0; i < 1000; i++ {
lm.Write(mediumMsg)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.GetHistory()
}
})
}
/*
Benchmark Results - MBP M1 Pro
Before (ring.Ring):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 43 ns | 40 B | 2 |
| MediumWrite (241B) | 76 ns | 264 B | 2 |
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
After (circularBuffer 10KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 67 ns | 240 B | 1 |
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
After (circularBuffer 100KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|-----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 66 ns | 240 B | 1 |
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
Summary:
- GetHistory: 139x faster (10KB), 18x faster (100KB)
- Allocations: reduced from 2 to 1 across all operations
- Small/medium writes: ~1.1-1.6x faster
*/
+292 -34
View File
@@ -2,6 +2,8 @@ package proxy
import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json"
"fmt"
"io"
@@ -26,6 +28,28 @@ type TokenMetrics struct {
PromptPerSecond float64 `json:"prompt_per_second"`
TokensPerSecond float64 `json:"tokens_per_second"`
DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"`
}
type ReqRespCapture struct {
ID int `json:"id"`
ReqPath string `json:"req_path"`
ReqHeaders map[string]string `json:"req_headers"`
ReqBody []byte `json:"req_body"`
RespHeaders map[string]string `json:"resp_headers"`
RespBody []byte `json:"resp_body"`
}
// Size returns the approximate memory usage of this capture in bytes
func (c *ReqRespCapture) Size() int {
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
for k, v := range c.ReqHeaders {
size += len(k) + len(v)
}
for k, v := range c.RespHeaders {
size += len(k) + len(v)
}
return size
}
// TokenMetricsEvent represents a token metrics event
@@ -44,19 +68,32 @@ type metricsMonitor struct {
maxMetrics int
nextID int
logger *LogMonitor
// capture fields
enableCaptures bool
captures map[int]ReqRespCapture // map for O(1) lookup by ID
captureOrder []int // track insertion order for FIFO eviction
captureSize int // current total size in bytes
maxCaptureSize int // max bytes for captures
}
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
mp := &metricsMonitor{
logger: logger,
maxMetrics: maxMetrics,
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
// capture buffer size in megabytes; 0 disables captures.
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
return &metricsMonitor{
logger: logger,
maxMetrics: maxMetrics,
enableCaptures: captureBufferMB > 0,
captures: make(map[int]ReqRespCapture),
captureOrder: make([]int, 0),
captureSize: 0,
maxCaptureSize: captureBufferMB * 1024 * 1024,
}
return mp
}
// addMetrics adds a new metric to the collection and publishes an event
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
// addMetrics adds a new metric to the collection and publishes an event.
// Returns the assigned metric ID.
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
mp.mu.Lock()
defer mp.mu.Unlock()
@@ -67,6 +104,49 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
}
event.Emit(TokenMetricsEvent{Metrics: metric})
return metric.ID
}
// addCapture adds a new capture to the buffer with size-based eviction.
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
if !mp.enableCaptures {
return
}
mp.mu.Lock()
defer mp.mu.Unlock()
captureSize := capture.Size()
if captureSize > mp.maxCaptureSize {
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
return
}
// Evict oldest (FIFO) until room available
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
oldestID := mp.captureOrder[0]
mp.captureOrder = mp.captureOrder[1:]
if evicted, exists := mp.captures[oldestID]; exists {
mp.captureSize -= evicted.Size()
delete(mp.captures, oldestID)
}
}
mp.captures[capture.ID] = capture
mp.captureOrder = append(mp.captureOrder, capture.ID)
mp.captureSize += captureSize
}
// getCaptureByID returns a capture by its ID, or nil if not found.
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
mp.mu.RLock()
defer mp.mu.RUnlock()
if capture, exists := mp.captures[id]; exists {
return &capture
}
return nil
}
// getMetrics returns a copy of the current metrics
@@ -95,7 +175,35 @@ func (mp *metricsMonitor) wrapHandler(
request *http.Request,
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error {
// Capture request body and headers if captures enabled
var reqBody []byte
var reqHeaders map[string]string
if mp.enableCaptures {
if request.Body != nil {
var err error
reqBody, err = io.ReadAll(request.Body)
if err != nil {
return fmt.Errorf("failed to read request body for capture: %w", err)
}
request.Body.Close()
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
}
reqHeaders = make(map[string]string)
for key, values := range request.Header {
if len(values) > 0 {
reqHeaders[key] = values[0]
}
}
redactHeaders(reqHeaders)
}
recorder := newBodyCopier(writer)
// Filter Accept-Encoding to only include encodings we can decompress for metrics
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
}
if err := next(modelID, recorder, request); err != nil {
return err
}
@@ -108,30 +216,94 @@ func (mp *metricsMonitor) wrapHandler(
return nil
}
// Initialize default metrics - these will always be recorded
tm := TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
}
body := recorder.body.Bytes()
if len(body) == 0 {
mp.logger.Warn("metrics skipped, empty body")
mp.logger.Warn("metrics: empty body, recording minimal metrics")
mp.addMetrics(tm)
return nil
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
} else {
// Decompress if needed
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
var err error
body, err = decompressBody(body, encoding)
if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
mp.addMetrics(tm)
return nil
}
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsed
}
} else {
if gjson.ValidBytes(body) {
if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
} else {
mp.addMetrics(tm)
parsed := gjson.ParseBytes(body)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
// extract timings for infill - response is an array, timings are in the last element
// see #463
if strings.HasPrefix(request.URL.Path, "/infill") {
if arr := parsed.Array(); len(arr) > 0 {
timings = arr[len(arr)-1].Get("timings")
}
}
if usage.Exists() || timings.Exists() {
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsedMetrics
}
}
} else {
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
}
}
// Build capture if enabled and determine if it will be stored
var capture *ReqRespCapture
if mp.enableCaptures {
respHeaders := make(map[string]string)
for key, values := range recorder.Header() {
if len(values) > 0 {
respHeaders[key] = values[0]
}
}
redactHeaders(respHeaders)
delete(respHeaders, "Content-Encoding")
capture = &ReqRespCapture{
ReqPath: request.URL.Path,
ReqHeaders: reqHeaders,
ReqBody: reqBody,
RespHeaders: respHeaders,
RespBody: body,
}
// Only set HasCapture if the capture will actually be stored (not too large)
if capture.Size() <= mp.maxCaptureSize {
tm.HasCapture = true
}
}
metricID := mp.addMetrics(tm)
// Store capture if enabled
if capture != nil {
capture.ID = metricID
mp.addCapture(*capture)
}
return nil
}
@@ -174,19 +346,27 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
}
if gjson.ValidBytes(data) {
return parseMetrics(modelID, start, gjson.ParseBytes(data))
parsed := gjson.ParseBytes(data)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
// v1/responses format nests usage under response.usage
if !usage.Exists() {
usage = parsed.Get("response.usage")
}
if usage.Exists() || timings.Exists() {
return parseMetrics(modelID, start, usage, timings)
}
}
}
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
}
func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) {
usage := jsonData.Get("usage")
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return TokenMetrics{}, fmt.Errorf("no usage or timings data found")
}
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
wallDurationMs := int(time.Since(start).Milliseconds())
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
@@ -195,22 +375,41 @@ func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (Token
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(start).Milliseconds())
durationMs := wallDurationMs
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
if pt := usage.Get("prompt_tokens"); pt.Exists() {
// v1/chat/completions
inputTokens = int(pt.Int())
} else if it := usage.Get("input_tokens"); it.Exists() {
// v1/messages
inputTokens = int(it.Int())
}
if ct := usage.Get("completion_tokens"); ct.Exists() {
// v1/chat/completions
outputTokens = int(ct.Int())
} else if ot := usage.Get("output_tokens"); ot.Exists() {
outputTokens = int(ot.Int())
}
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
cachedTokens = int(ct.Int())
}
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
inputTokens = int(timings.Get("prompt_n").Int())
outputTokens = int(timings.Get("predicted_n").Int())
promptPerSecond = timings.Get("prompt_per_second").Float()
tokensPerSecond = timings.Get("predicted_per_second").Float()
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
if timingsDurationMs > durationMs {
durationMs = timingsDurationMs
}
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
@@ -227,6 +426,25 @@ func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (Token
}, nil
}
// decompressBody decompresses the body based on Content-Encoding header
func decompressBody(body []byte, encoding string) ([]byte, error) {
switch strings.ToLower(strings.TrimSpace(encoding)) {
case "gzip":
reader, err := gzip.NewReader(bytes.NewReader(body))
if err != nil {
return nil, err
}
defer reader.Close()
return io.ReadAll(reader)
case "deflate":
reader := flate.NewReader(bytes.NewReader(body))
defer reader.Close()
return io.ReadAll(reader)
default:
return body, nil // Return as-is for unknown/no encoding
}
}
// responseBodyCopier records the response body and writes to the original response writer
// while also capturing it in a buffer for later processing
type responseBodyCopier struct {
@@ -265,3 +483,43 @@ func (w *responseBodyCopier) Header() http.Header {
func (w *responseBodyCopier) StartTime() time.Time {
return w.start
}
// sensitiveHeaders lists headers that should be redacted in captures
var sensitiveHeaders = map[string]bool{
"authorization": true,
"proxy-authorization": true,
"cookie": true,
"set-cookie": true,
"x-api-key": true,
}
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
func redactHeaders(headers map[string]string) {
for key := range headers {
if sensitiveHeaders[strings.ToLower(key)] {
headers[key] = "[REDACTED]"
}
}
}
// filterAcceptEncoding filters the Accept-Encoding header to only include
// encodings we can decompress (gzip, deflate). This respects the client's
// preferences while ensuring we can parse response bodies for metrics.
func filterAcceptEncoding(acceptEncoding string) string {
if acceptEncoding == "" {
return ""
}
supported := map[string]bool{"gzip": true, "deflate": true}
var filtered []string
for part := range strings.SplitSeq(acceptEncoding, ",") {
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
if supported[strings.ToLower(encoding)] {
filtered = append(filtered, strings.TrimSpace(part))
}
}
return strings.Join(filtered, ", ")
}
+524 -38
View File
@@ -1,6 +1,9 @@
package proxy
import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -11,11 +14,12 @@ import (
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
func TestMetricsMonitor_AddMetrics(t *testing.T) {
t.Run("adds metrics and assigns ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
metric := TokenMetrics{
Model: "test-model",
@@ -34,7 +38,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
})
t.Run("increments ID for each metric", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{Model: "model"})
@@ -48,7 +52,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
})
t.Run("respects max metrics limit", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 3)
mm := newMetricsMonitor(testLogger, 3, 0)
// Add 5 metrics
for i := 0; i < 5; i++ {
@@ -68,7 +72,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
})
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
receivedEvent := make(chan TokenMetricsEvent, 1)
cancel := event.On(func(e TokenMetricsEvent) {
@@ -98,14 +102,14 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
func TestMetricsMonitor_GetMetrics(t *testing.T) {
t.Run("returns empty slice when no metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
metrics := mm.getMetrics()
assert.NotNil(t, metrics)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns copy of metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
mm.addMetrics(TokenMetrics{Model: "model1"})
mm.addMetrics(TokenMetrics{Model: "model2"})
@@ -125,7 +129,7 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
assert.NotNil(t, jsonData)
@@ -137,7 +141,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
})
t.Run("returns valid JSON with metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
mm.addMetrics(TokenMetrics{
Model: "model1",
InputTokens: 100,
@@ -165,7 +169,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
func TestMetricsMonitor_WrapHandler(t *testing.T) {
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `{
"usage": {
@@ -196,7 +200,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
})
t.Run("successful request with timings data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `{
"timings": {
@@ -236,7 +240,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
})
t.Run("streaming request with SSE format", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
// Note: SSE format requires proper line breaks - each data line followed by blank line
responseBody := `data: {"choices":[{"text":"Hello"}]}
@@ -272,7 +276,7 @@ data: [DONE]
})
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusBadRequest)
@@ -291,8 +295,8 @@ data: [DONE]
assert.Equal(t, 0, len(metrics))
})
t.Run("empty response body does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
t.Run("empty response body records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusOK)
@@ -307,11 +311,14 @@ data: [DONE]
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
@@ -328,11 +335,14 @@ data: [DONE]
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("next handler error is propagated", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
expectedErr := assert.AnError
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
@@ -350,8 +360,8 @@ data: [DONE]
assert.Equal(t, 0, len(metrics))
})
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
t.Run("response without usage or timings records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `{"result": "ok"}`
@@ -367,10 +377,82 @@ data: [DONE]
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
// Infill response is an array with timings in the last element
responseBody := `[
{"content": "first chunk"},
{"content": "second chunk"},
{"content": "final", "timings": {
"prompt_n": 150,
"predicted_n": 75,
"prompt_per_second": 200.5,
"predicted_per_second": 35.5,
"prompt_ms": 600.0,
"predicted_ms": 1800.0,
"cache_n": 30
}}
]`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/infill", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 150, metrics[0].InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens)
assert.Equal(t, 30, metrics[0].CachedTokens)
assert.Equal(t, 200.5, metrics[0].PromptPerSecond)
assert.Equal(t, 35.5, metrics[0].TokensPerSecond)
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
})
t.Run("infill request with empty array records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `[]`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/infill", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
}
@@ -425,7 +507,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
func TestMetricsMonitor_Concurrent(t *testing.T) {
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 1000)
mm := newMetricsMonitor(testLogger, 1000, 0)
var wg sync.WaitGroup
numGoroutines := 10
@@ -452,7 +534,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
})
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 100)
mm := newMetricsMonitor(testLogger, 100, 0)
done := make(chan bool)
@@ -489,8 +571,29 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
}
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
t.Run("keeps wall clock duration when timings underreport request time", func(t *testing.T) {
start := time.Now().Add(-5 * time.Second)
usage := gjson.Parse(`{"prompt_tokens": 5, "completion_tokens": 1}`)
timings := gjson.Parse(`{
"prompt_n": 5,
"predicted_n": 1,
"prompt_per_second": 10.0,
"predicted_per_second": 2.0,
"prompt_ms": 5.0,
"predicted_ms": 15.0
}`)
metrics, err := parseMetrics("test-model", start, usage, timings)
assert.NoError(t, err)
assert.Equal(t, 5, metrics.InputTokens)
assert.Equal(t, 1, metrics.OutputTokens)
assert.Equal(t, 10.0, metrics.PromptPerSecond)
assert.Equal(t, 2.0, metrics.TokensPerSecond)
assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
})
t.Run("prefers timings over usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
// Timings should take precedence over usage
responseBody := `{
@@ -530,7 +633,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
})
t.Run("handles missing cache_n in timings", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `{
"timings": {
@@ -565,7 +668,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm := newMetricsMonitor(testLogger, 10, 0)
// Metrics should be found in the last data line before [DONE]
responseBody := `data: {"choices":[{"text":"First"}]}
@@ -598,8 +701,8 @@ data: [DONE]
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `data: not json
@@ -619,14 +722,46 @@ data: [DONE]
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("handles empty streaming response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
// v1/responses SSE format: usage is nested under response.usage
responseBody := "event: response.completed\n" +
`data: {"type":"response.completed","response":{"id":"resp_abc","object":"response","created_at":1773416985,"status":"completed","model":"test-model","output":[],"usage":{"input_tokens":17,"output_tokens":23,"total_tokens":40}}}` +
"\n\n"
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/v1/responses", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 17, metrics[0].InputTokens)
assert.Equal(t, 23, metrics[0].OutputTokens)
})
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := ``
@@ -642,17 +777,19 @@ data: [DONE]
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
// Empty body should not trigger WrapHandler processing
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
}
// Benchmark tests
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
mm := newMetricsMonitor(testLogger, 1000)
mm := newMetricsMonitor(testLogger, 1000, 0)
metric := TokenMetrics{
Model: "test-model",
@@ -673,7 +810,7 @@ func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
// Test performance with a smaller buffer where wrapping occurs more frequently
mm := newMetricsMonitor(testLogger, 100)
mm := newMetricsMonitor(testLogger, 100, 0)
metric := TokenMetrics{
Model: "test-model",
@@ -691,3 +828,352 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
mm.addMetrics(metric)
}
}
func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
t.Run("gzip encoded response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
// Compress with gzip
var buf bytes.Buffer
gzWriter := gzip.NewWriter(&buf)
gzWriter.Write([]byte(responseBody))
gzWriter.Close()
compressedBody := buf.Bytes()
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "gzip")
w.WriteHeader(http.StatusOK)
w.Write(compressedBody)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("deflate encoded response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}`
// Compress with deflate
var buf bytes.Buffer
flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression)
flateWriter.Write([]byte(responseBody))
flateWriter.Close()
compressedBody := buf.Bytes()
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "deflate")
w.WriteHeader(http.StatusOK)
w.Write(compressedBody)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 200, metrics[0].InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens)
})
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
// Invalid compressed data
invalidData := []byte("this is not gzip data")
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "gzip")
w.WriteHeader(http.StatusOK)
w.Write(invalidData)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Should not return error, just log warning
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "unknown-encoding")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 300, metrics[0].InputTokens)
assert.Equal(t, 100, metrics[0].OutputTokens)
})
}
func TestReqRespCapture_Size(t *testing.T) {
t.Run("calculates size correctly", func(t *testing.T) {
capture := ReqRespCapture{
ID: 1,
ReqPath: "/v1/chat/completions", // 20 bytes
ReqHeaders: map[string]string{
"Content-Type": "application/json", // 12 + 16 = 28
},
ReqBody: []byte("request body"), // 12 bytes
RespHeaders: map[string]string{
"X-Test": "value", // 6 + 5 = 11
},
RespBody: []byte("response body"), // 13 bytes
}
// Expected: 20 + 12 + 13 + 28 + 11 = 84
assert.Equal(t, 84, capture.Size())
})
t.Run("handles empty capture", func(t *testing.T) {
capture := ReqRespCapture{}
assert.Equal(t, 0, capture.Size())
})
}
func TestMetricsMonitor_AddCapture(t *testing.T) {
t.Run("does nothing when captures disabled", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
capture := ReqRespCapture{
ID: 0,
ReqBody: []byte("test"),
}
mm.addCapture(capture)
// Should not store capture
assert.Nil(t, mm.getCaptureByID(0))
})
t.Run("adds capture when enabled", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{
ID: 0,
ReqBody: []byte("test request"),
RespBody: []byte("test response"),
}
mm.addCapture(capture)
retrieved := mm.getCaptureByID(0)
assert.NotNil(t, retrieved)
assert.Equal(t, 0, retrieved.ID)
assert.Equal(t, []byte("test request"), retrieved.ReqBody)
assert.Equal(t, []byte("test response"), retrieved.RespBody)
})
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100 // Set small limit for test
// Add captures that will exceed the limit
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)}
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)}
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)}
mm.addCapture(capture1)
mm.addCapture(capture2)
// Adding capture3 should evict capture1
mm.addCapture(capture3)
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
})
t.Run("skips capture larger than max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100
// Add a capture larger than max
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)}
mm.addCapture(largeCapture)
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
})
}
func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
t.Run("returns nil for non-existent ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
assert.Nil(t, mm.getCaptureByID(999))
})
t.Run("returns capture by ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{
ID: 42,
ReqBody: []byte("test"),
}
mm.addCapture(capture)
retrieved := mm.getCaptureByID(42)
assert.NotNil(t, retrieved)
assert.Equal(t, 42, retrieved.ID)
})
}
func TestRedactHeaders(t *testing.T) {
t.Run("redacts sensitive headers", func(t *testing.T) {
headers := map[string]string{
"Authorization": "Bearer secret-token",
"Proxy-Authorization": "Basic creds",
"Cookie": "session=abc123",
"Set-Cookie": "session=xyz789",
"X-Api-Key": "sk-12345",
"Content-Type": "application/json",
"X-Custom": "safe-value",
}
redactHeaders(headers)
assert.Equal(t, "[REDACTED]", headers["Authorization"])
assert.Equal(t, "[REDACTED]", headers["Proxy-Authorization"])
assert.Equal(t, "[REDACTED]", headers["Cookie"])
assert.Equal(t, "[REDACTED]", headers["Set-Cookie"])
assert.Equal(t, "[REDACTED]", headers["X-Api-Key"])
assert.Equal(t, "application/json", headers["Content-Type"])
assert.Equal(t, "safe-value", headers["X-Custom"])
})
t.Run("handles mixed case header names", func(t *testing.T) {
headers := map[string]string{
"authorization": "Bearer token",
"COOKIE": "session=abc",
"x-api-key": "key123",
}
redactHeaders(headers)
assert.Equal(t, "[REDACTED]", headers["authorization"])
assert.Equal(t, "[REDACTED]", headers["COOKIE"])
assert.Equal(t, "[REDACTED]", headers["x-api-key"])
})
t.Run("handles empty headers", func(t *testing.T) {
headers := map[string]string{}
redactHeaders(headers)
assert.Empty(t, headers)
})
}
func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
t.Run("captures request and response when enabled", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
requestBody := `{"model": "test", "prompt": "hello"}`
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Custom", "header-value")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer secret")
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
// Check metric was recorded
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
metricID := metrics[0].ID
// Check capture was stored with same ID
capture := mm.getCaptureByID(metricID)
assert.NotNil(t, capture)
assert.Equal(t, metricID, capture.ID)
assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Equal(t, []byte(responseBody), capture.RespBody)
assert.Equal(t, "/test", capture.ReqPath)
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
})
t.Run("does not capture when disabled", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)
requestBody := `{"model": "test"}`
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
// Metrics should still be recorded
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
// But no capture
capture := mm.getCaptureByID(metrics[0].ID)
assert.Nil(t, capture)
})
}
+143
View File
@@ -0,0 +1,143 @@
package proxy
import (
"fmt"
"net"
"net/http"
"net/http/httputil"
"runtime"
"sort"
"strings"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
)
type peerProxyMember struct {
peerID string
reverseProxy *httputil.ReverseProxy
apiKey string
}
type PeerProxy struct {
peers config.PeerDictionaryConfig
proxyMap map[string]*peerProxyMember
}
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
proxyMap := make(map[string]*peerProxyMember)
// Sort peer IDs for consistent iteration order
peerIDs := make([]string, 0, len(peers))
for peerID := range peers {
peerIDs = append(peerIDs, peerID)
}
sort.Strings(peerIDs)
for _, peerID := range peerIDs {
peer := peers[peerID]
// Create a transport with per-peer timeout configuration
peerTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
}
// Create reverse proxy for this peer
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
reverseProxy.Transport = peerTransport
// Wrap Director to set Host header for remote hosts (not localhost)
originalDirector := reverseProxy.Director
reverseProxy.Director = func(req *http.Request) {
originalDirector(req)
// Ensure Host header matches target URL for remote proxying
req.Host = req.URL.Host
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
errMsg := fmt.Sprintf("peer proxy error: %v", err)
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
}
http.Error(w, errMsg, http.StatusBadGateway)
}
pp := &peerProxyMember{
peerID: peerID,
reverseProxy: reverseProxy,
apiKey: peer.ApiKey,
}
// Map each model to this peer's proxy
for _, modelID := range peer.Models {
if _, found := proxyMap[modelID]; found {
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
continue
}
proxyMap[modelID] = pp
}
}
return &PeerProxy{
peers: peers,
proxyMap: proxyMap,
}, nil
}
func (p *PeerProxy) HasPeerModel(modelID string) bool {
_, found := p.proxyMap[modelID]
return found
}
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
pp, found := p.proxyMap[modelID]
if !found {
return config.Filters{}
}
// Get the peer config using the peerID
peer, found := p.peers[pp.peerID]
if !found {
return config.Filters{}
}
return peer.Filters
}
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
return p.peers
}
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
pp, found := p.proxyMap[model_id]
if !found {
return fmt.Errorf("no peer proxy found for model %s", model_id)
}
// Inject API key if configured for this peer
if pp.apiKey != "" {
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
request.Header.Set("x-api-key", pp.apiKey)
}
pp.reverseProxy.ServeHTTP(writer, request)
return nil
}
+311
View File
@@ -0,0 +1,311 @@
package proxy
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
peers := config.PeerDictionaryConfig{}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.NotNil(t, pm)
assert.Empty(t, pm.proxyMap)
}
func TestNewPeerProxy_SinglePeer(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
ApiKey: "test-key",
Models: []string{"model-a", "model-b"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.Len(t, pm.proxyMap, 2)
assert.True(t, pm.HasPeerModel("model-a"))
assert.True(t, pm.HasPeerModel("model-b"))
assert.False(t, pm.HasPeerModel("model-c"))
}
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"model-a", "model-b"},
},
"peer2": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"model-c", "model-d"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.Len(t, pm.proxyMap, 4)
assert.True(t, pm.HasPeerModel("model-a"))
assert.True(t, pm.HasPeerModel("model-b"))
assert.True(t, pm.HasPeerModel("model-c"))
assert.True(t, pm.HasPeerModel("model-d"))
}
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
// should be mapped, and a warning should be logged
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"alpha-peer": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"duplicate-model"},
},
"beta-peer": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"duplicate-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
// Should only have one entry for the duplicate model
assert.Len(t, pm.proxyMap, 1)
assert.True(t, pm.HasPeerModel("duplicate-model"))
}
func TestHasPeerModel(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
Models: []string{"existing-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.True(t, pm.HasPeerModel("existing-model"))
assert.False(t, pm.HasPeerModel("non-existing-model"))
}
func TestProxyRequest_ModelNotFound(t *testing.T) {
peers := config.PeerDictionaryConfig{}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("non-existing-model", w, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
}
func TestProxyRequest_Success(t *testing.T) {
// Create a test server to act as the peer
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("response from peer"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "response from peer", w.Body.String())
}
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
// Create a test server that checks for the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "secret-api-key",
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
}
func TestProxyRequest_NoApiKey(t *testing.T) {
// Create a test server that checks for the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "", // No API key
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Empty(t, receivedAuthHeader)
}
func TestProxyRequest_HostHeaderSet(t *testing.T) {
// Create a test server that checks the Host header
var receivedHost string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHost = r.Host
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
// The Host header should be set to the target URL's host
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
}
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
// Create a test server that returns SSE content type
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
// The X-Accel-Buffering header should be set to "no" for SSE
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
}
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
proxyURL, _ := url.Parse("http://localhost:8080")
peers := config.PeerDictionaryConfig{
"test-peer": config.PeerConfig{
Proxy: "http://localhost:8080",
ProxyURL: proxyURL,
Models: []string{"model1"},
Timeouts: config.TimeoutsConfig{
Connect: 45,
ResponseHeader: 300,
TLSHandshake: 15,
ExpectContinue: 2,
IdleConn: 120,
},
},
}
peerProxy, err := NewPeerProxy(peers, testLogger)
assert.NoError(t, err)
assert.NotNil(t, peerProxy)
assert.True(t, peerProxy.HasPeerModel("model1"))
// Verify the timeout values are actually applied to the transport
member, found := peerProxy.proxyMap["model1"]
require.True(t, found, "model1 should exist in proxyMap")
assert.NotNil(t, member.reverseProxy)
assert.NotNil(t, member.reverseProxy.Transport)
transport, ok := member.reverseProxy.Transport.(*http.Transport)
require.True(t, ok, "Transport should be *http.Transport")
// Verify all timeout values are correctly applied
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
// ForceAttemptHTTP2 should be enabled
assert.True(t, transport.ForceAttemptHTTP2)
}
+30 -2
View File
@@ -96,6 +96,24 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
var reverseProxy *httputil.ReverseProxy
if proxyURL != nil {
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
// Create custom transport with configured timeouts
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second,
}
reverseProxy.Transport = transport
reverseProxy.ModifyResponse = func(resp *http.Response) error {
// prevent nginx from buffering streaming responses (e.g., SSE)
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
@@ -414,6 +432,9 @@ func (p *Process) stopCommand() {
stopStartTime := time.Now()
defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
// free the buffer in processLogger so the memory can be recovered
p.processLogger.Clear()
}()
p.cmdMutex.RLock()
@@ -507,7 +528,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// add a sync so the streaming client only runs when the goroutine has exited
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming {
// PR #417 (no support for anthropic v1/messages yet)
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
srw = newStatusResponseWriter(p, w)
go srw.statusUpdates(swapCtx)
} else {
@@ -643,6 +667,11 @@ func (p *Process) cmdStopUpstreamProcess() error {
return nil
}
// Logger returns the logger for this process.
func (p *Process) Logger() *LogMonitor {
return p.processLogger
}
var loadingRemarks = []string{
"Still faster than your last standup meeting...",
"Reticulating splines...",
@@ -861,7 +890,6 @@ func (s *statusResponseWriter) WriteHeader(statusCode int) {
s.Flush()
}
// Add Flush method
func (s *statusResponseWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush()
+51 -10
View File
@@ -2,6 +2,7 @@ package proxy
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
@@ -117,12 +118,12 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
}
expectedMessage := "I_sense_imminent_danger"
config := getTestSimpleResponderConfig(expectedMessage)
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
conf := getTestSimpleResponderConfig(expectedMessage)
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
conf.UnloadAfter = 3 // seconds
assert.Equal(t, 3, conf.UnloadAfter)
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
defer process.Stop()
// this should take 4 seconds
@@ -159,12 +160,12 @@ func TestProcess_LowTTLValue(t *testing.T) {
t.Skip("skipping test, edit process_test.go to run it ")
}
config := getTestSimpleResponderConfig("fast_ttl")
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
conf := getTestSimpleResponderConfig("fast_ttl")
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
conf.UnloadAfter = 1 // second
assert.Equal(t, 1, conf.UnloadAfter)
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
defer process.Stop()
for i := 0; i < 100; i++ {
@@ -395,6 +396,10 @@ func TestProcess_StopImmediately(t *testing.T) {
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
// the upstream command
func TestProcess_ForceStopWithKill(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
if runtime.GOOS == "windows" {
t.Skip("skipping SIGTERM test on Windows ")
}
@@ -565,3 +570,39 @@ func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
}
return w.ResponseRecorder.Write(b)
}
func TestProcess_CustomTimeouts(t *testing.T) {
modelConfig := config.ModelConfig{
Cmd: "echo test",
Proxy: "http://localhost:8080",
CheckEndpoint: "/health",
Timeouts: config.TimeoutsConfig{
Connect: 45,
ResponseHeader: 120,
TLSHandshake: 15,
ExpectContinue: 2,
IdleConn: 120,
},
}
debugLogger := NewLogMonitorWriter(io.Discard)
process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger)
// Verify the process was created successfully
assert.NotNil(t, process)
assert.Equal(t, "test-model", process.ID)
assert.NotNil(t, process.reverseProxy)
assert.NotNil(t, process.reverseProxy.Transport)
// Verify it's using http.Transport (not some other type)
transport, ok := process.reverseProxy.Transport.(*http.Transport)
assert.True(t, ok, "Transport should be *http.Transport")
assert.NotNil(t, transport)
// Verify the timeouts are correctly applied
assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout)
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
assert.True(t, transport.ForceAttemptHTTP2)
}
+9 -1
View File
@@ -46,7 +46,8 @@ func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, u
// Create a Process for each member in the group
for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
processLogger := NewLogMonitorWriter(upstreamLogger)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
pg.processes[modelID] = process
}
@@ -88,6 +89,13 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
}
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
if pg.HasMember(modelName) {
return pg.processes[modelName], true
}
return nil, false
}
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
pg.Lock()
+4
View File
@@ -49,6 +49,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
// and multiple requests are made in parallel, only one process is running at a time.
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
+467 -166
View File
@@ -3,6 +3,7 @@ package proxy
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"mime/multipart"
@@ -27,6 +28,40 @@ const (
type proxyCtxKey string
type InflightCounter struct {
mu sync.Mutex
total int
}
func newInflightCounter() *InflightCounter {
return &InflightCounter{}
}
func (ic *InflightCounter) Current() int {
ic.mu.Lock()
total := ic.total
ic.mu.Unlock()
return total
}
func (ic *InflightCounter) Increment() int {
ic.mu.Lock()
ic.total++
total := ic.total
ic.mu.Unlock()
return total
}
func (ic *InflightCounter) Decrement() int {
ic.mu.Lock()
if ic.total > 0 {
ic.total--
}
total := ic.total
ic.mu.Unlock()
return total
}
type ProxyManager struct {
sync.Mutex
@@ -42,6 +77,8 @@ type ProxyManager struct {
processGroups map[string]*ProcessGroup
inFlightCounter *InflightCounter
// shutdown signaling
shutdownCtx context.Context
shutdownCancel context.CancelFunc
@@ -50,19 +87,42 @@ type ProxyManager struct {
buildDate string
commit string
version string
// peer proxy see: #296, #433
peerProxy *PeerProxy
}
func New(config config.Config) *ProxyManager {
func New(proxyConfig config.Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests {
var muxLogger, upstreamLogger, proxyLogger *LogMonitor
switch proxyConfig.LogToStdout {
case config.LogToStdoutNone:
muxLogger = NewLogMonitorWriter(io.Discard)
upstreamLogger = NewLogMonitorWriter(io.Discard)
proxyLogger = NewLogMonitorWriter(io.Discard)
case config.LogToStdoutBoth:
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(muxLogger)
proxyLogger = NewLogMonitorWriter(muxLogger)
case config.LogToStdoutUpstream:
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(muxLogger)
proxyLogger = NewLogMonitorWriter(io.Discard)
default:
// same as config.LogToStdoutProxy
// helpful because some old tests create a config.Config directly and it
// may not have LogToStdout set explicitly
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(io.Discard)
proxyLogger = NewLogMonitorWriter(muxLogger)
}
if proxyConfig.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
}
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
@@ -99,7 +159,7 @@ func New(config config.Config) *ProxyManager {
"stampnano": time.StampNano,
}
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(config.LogTimeFormat))]; ok {
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
proxyLogger.SetLogTimeFormat(timeFormat)
upstreamLogger.SetLogTimeFormat(timeFormat)
}
@@ -107,61 +167,78 @@ func New(config config.Config) *ProxyManager {
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
var maxMetrics int
if config.MetricsMaxInMemory <= 0 {
if proxyConfig.MetricsMaxInMemory <= 0 {
maxMetrics = 1000 // Default fallback
} else {
maxMetrics = config.MetricsMaxInMemory
maxMetrics = proxyConfig.MetricsMaxInMemory
}
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
if err != nil {
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
peerProxy = nil
}
pm := &ProxyManager{
config: config,
config: proxyConfig,
ginEngine: gin.New(),
proxyLogger: proxyLogger,
muxLogger: stdoutLogger,
muxLogger: muxLogger,
upstreamLogger: upstreamLogger,
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer),
processGroups: make(map[string]*ProcessGroup),
inFlightCounter: newInflightCounter(),
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
buildDate: "unknown",
commit: "abcd1234",
version: "0",
peerProxy: peerProxy,
}
// create the process groups
for groupID := range config.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
for groupID := range proxyConfig.Groups {
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup
}
pm.setupGinEngine()
// run any startup hooks
if len(config.Hooks.OnStartup.Preload) > 0 {
if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
// do it in the background, don't block startup -- not sure if good idea yet
go func() {
discardWriter := &DiscardWriter{}
for _, realModelName := range config.Hooks.OnStartup.Preload {
proxyLogger.Infof("Preloading model: %s", realModelName)
processGroup, _, err := pm.swapProcessGroup(realModelName)
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
modelID, ok := proxyConfig.RealModelName(preloadModelName)
if !ok {
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
continue
}
proxyLogger.Infof("Preloading model: %s", modelID)
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
event.Emit(ModelPreloadedEvent{
ModelName: realModelName,
ModelName: modelID,
Success: false,
})
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
continue
} else {
req, _ := http.NewRequest("GET", "/", nil)
processGroup.ProxyRequest(realModelName, discardWriter, req)
processGroup.ProxyRequest(modelID, discardWriter, req)
event.Emit(ModelPreloadedEvent{
ModelName: realModelName,
ModelName: modelID,
Success: true,
})
}
@@ -236,35 +313,50 @@ func (pm *ProxyManager) setupGinEngine() {
})
// Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
// Protected routes use pm.apiKeyAuth() middleware
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
// Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
// Support anthropic count_tokens API (Also added in the above PR)
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
// Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
// llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/reranking", pm.proxyOAIHandler)
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
// llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", pm.proxyOAIHandler)
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
// llama-server's /completion endpoint
pm.ginEngine.POST("/completion", pm.proxyOAIHandler)
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
// sd.cpp /sdapi/v1 endpoints
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
// in proxymanager_loghandlers.go
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
/**
* User Interface Endpoints
@@ -276,9 +368,9 @@ func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
c.Redirect(http.StatusFound, "/ui/models")
})
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
pm.ginEngine.GET("/health", func(c *gin.Context) {
c.String(http.StatusOK, "OK")
})
@@ -300,25 +392,35 @@ func (pm *ProxyManager) setupGinEngine() {
if err != nil {
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
} else {
// Serve files with compression support under /ui/*
// This handler checks for pre-compressed .br and .gz files
pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) {
filepath := strings.TrimPrefix(c.Param("filepath"), "/")
// Default to index.html for directory-like paths
if filepath == "" {
filepath = "index.html"
}
// serve files that exist under /ui/*
pm.ginEngine.StaticFS("/ui", reactFS)
ServeCompressedFile(reactFS, c.Writer, c.Request, filepath)
})
// server SPA for UI under /ui/*
// Serve SPA for UI under /ui/* - fallback to index.html for client-side routing
pm.ginEngine.NoRoute(func(c *gin.Context) {
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
c.AbortWithStatus(http.StatusNotFound)
return
}
file, err := reactFS.Open("index.html")
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
// Check if this looks like a file request (has extension)
path := c.Request.URL.Path
if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") {
// This was likely a file request that wasn't found
c.AbortWithStatus(http.StatusNotFound)
return
}
defer file.Close()
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
// Serve index.html for SPA routing
ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html")
})
}
@@ -330,6 +432,14 @@ func (pm *ProxyManager) setupGinEngine() {
gin.DisableConsoleColor()
}
func (pm *ProxyManager) trackInflight() gin.HandlerFunc {
return func(c *gin.Context) {
event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Increment()})
defer event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Decrement()})
c.Next()
}
}
// ServeHTTP implements http.Handler interface
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r)
@@ -376,16 +486,10 @@ func (pm *ProxyManager) Shutdown() {
pm.shutdownCancel()
}
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
}
if processGroup.exclusive {
@@ -397,54 +501,71 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
}
}
return processGroup, realModelName, nil
return processGroup, nil
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := make([]gin.H, 0, len(pm.config.Models))
createdTime := time.Now().Unix()
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
record := gin.H{
"id": modelId,
"object": "model",
"created": createdTime,
"owned_by": "llama-swap",
}
if name := strings.TrimSpace(modelConfig.Name); name != "" {
record["name"] = name
}
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
record["description"] = desc
}
// Add metadata if present
if len(modelConfig.Metadata) > 0 {
record["meta"] = gin.H{
"llamaswap": modelConfig.Metadata,
}
}
return record
}
for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted {
continue
}
newRecord := func(modelId string) gin.H {
record := gin.H{
"id": modelId,
"object": "model",
"created": createdTime,
"owned_by": "llama-swap",
}
if name := strings.TrimSpace(modelConfig.Name); name != "" {
record["name"] = name
}
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
record["description"] = desc
}
// Add metadata if present
if len(modelConfig.Metadata) > 0 {
record["meta"] = gin.H{
"llamaswap": modelConfig.Metadata,
}
}
return record
}
data = append(data, newRecord(id))
data = append(data, newRecord(id, modelConfig))
// Include aliases
if pm.config.IncludeAliasesInList {
for _, alias := range modelConfig.Aliases {
if alias := strings.TrimSpace(alias); alias != "" {
data = append(data, newRecord(alias))
data = append(data, newRecord(alias, modelConfig))
}
}
}
}
if pm.peerProxy != nil {
for peerID, peer := range pm.peerProxy.ListPeers() {
// add peer models
for _, modelID := range peer.Models {
// Skip unlisted models if not showing them
record := newRecord(modelID, config.ModelConfig{
Name: fmt.Sprintf("%s: %s", peerID, modelID),
Metadata: map[string]any{
"peerID": peerID,
},
})
data = append(data, record)
}
}
}
// Sort by the "id" key
sort.Slice(data, func(i, j int) bool {
si, _ := data[i]["id"].(string)
@@ -464,62 +585,61 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
})
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
upstreamPath := c.Param("upstreamPath")
// split the upstream path by / and search for the model name
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
if len(parts) == 0 {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return
}
modelFound := false
// findModelInPath searches for a valid model name in a path with slashes.
// It iteratively builds up path segments until it finds a matching model.
// Returns: (searchModelName, realModelName, remainingPath, found)
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
searchModelName := ""
var modelName, remainingPath string
for i, part := range parts {
if parts[i] == "" {
if part == "" {
continue
}
if searchModelName == "" {
searchModelName = part
} else {
searchModelName = searchModelName + "/" + parts[i]
searchModelName = searchModelName + "/" + part
}
if real, ok := pm.config.RealModelName(searchModelName); ok {
modelName = real
remainingPath = "/" + strings.Join(parts[i+1:], "/")
modelFound = true
// Check if this is exactly a model name with no additional path
// and doesn't end with a trailing slash
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
// Build new URL with query parameters preserved
newPath := "/upstream/" + searchModelName + "/"
if c.Request.URL.RawQuery != "" {
newPath += "?" + c.Request.URL.RawQuery
}
// Use 308 for non-GET/HEAD requests to preserve method
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
c.Redirect(http.StatusMovedPermanently, newPath)
} else {
c.Redirect(http.StatusPermanentRedirect, newPath)
}
return
}
break
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
}
}
return "", "", "", false
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
upstreamPath := c.Param("upstreamPath")
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
if !modelFound {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return
}
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
// This ensures relative URLs in upstream responses resolve correctly and
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
// HTTP method (301 would downgrade to GET).
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
newPath := "/upstream/" + searchModelName + "/"
if c.Request.URL.RawQuery != "" {
newPath += "?" + c.Request.URL.RawQuery
}
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
c.Redirect(http.StatusMovedPermanently, newPath)
} else {
c.Redirect(http.StatusPermanentRedirect, newPath)
}
return
}
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
@@ -531,21 +651,21 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return
}
} else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
return
}
}
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
@@ -558,41 +678,101 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
return
}
// Look for a matching local model first
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
processGroup, _, err := pm.swapProcessGroup(realModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[modelID].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
return
}
}
}
// issue #453 set/override parameters in the JSON body
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
for _, key := range setParamsByIDKeys {
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
// issue #453 apply filters for peer requests
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
// Apply stripParams - remove specified parameters from request
stripParams := peerFilters.SanitizedStripParams()
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
return
}
}
// Apply setParams - set/override specified parameters in request
setParams, setParamKeys := peerFilters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -605,19 +785,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
// issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return
}
} else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
@@ -637,9 +817,29 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return
}
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
// Look for a matching local model first, then check peers
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
var useModelName string
modelID, found := pm.config.RealModelName(requestedModel)
if found {
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
useModelName = pm.config.Models[modelID].UseModelName
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
return
}
@@ -655,8 +855,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// If this is the model field and we have a profile, use just the model name
if key == "model" {
// # issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
fieldValue = useModelName
} else {
@@ -726,9 +924,46 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return
}
}
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
requestedModel := c.Query("model")
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter")
return
}
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
var modelID string
if realModelID, found := pm.config.RealModelName(requestedModel); found {
processGroup, err := pm.swapProcessGroup(realModelID)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return
}
modelID = realModelID
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
modelID = requestedModel
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
return
}
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID)
return
}
}
@@ -743,6 +978,67 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
}
}
// apiKeyAuth returns a middleware that validates API keys if configured.
// Returns a pass-through handler if no API keys are configured.
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
if len(pm.config.RequiredAPIKeys) == 0 {
return func(c *gin.Context) { c.Next() }
}
return func(c *gin.Context) {
xApiKey := c.GetHeader("x-api-key")
var bearerKey string
var basicKey string
if auth := c.GetHeader("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
bearerKey = strings.TrimPrefix(auth, "Bearer ")
} else if strings.HasPrefix(auth, "Basic ") {
// Basic Auth: base64(username:password), password is the API key
encoded := strings.TrimPrefix(auth, "Basic ")
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) == 2 {
basicKey = parts[1] // password is the API key
}
}
}
}
// Use first key found: Basic, then Bearer, then x-api-key
var providedKey string
if basicKey != "" {
providedKey = basicKey
} else if bearerKey != "" {
providedKey = bearerKey
} else {
providedKey = xApiKey
}
// Validate key
valid := false
for _, key := range pm.config.RequiredAPIKeys {
if providedKey == key {
valid = true
break
}
}
if !valid {
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
c.Abort()
return
}
// Strip auth headers to prevent leakage to upstream
c.Request.Header.Del("Authorization")
c.Request.Header.Del("x-api-key")
c.Next()
}
}
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK")
@@ -756,8 +1052,13 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
for _, process := range processGroup.processes {
if process.CurrentState() == StateReady {
runningProcesses = append(runningProcesses, gin.H{
"model": process.ID,
"state": process.state,
"model": process.ID,
"state": process.state,
"cmd": process.config.Cmd,
"proxy": process.config.Proxy,
"ttl": process.config.UnloadAfter,
"name": process.config.Name,
"description": process.config.Description,
})
}
}
+62 -6
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"github.com/gin-gonic/gin"
@@ -13,22 +14,26 @@ import (
)
type Model struct {
Id string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
Id string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
PeerID string `json:"peerID"`
Aliases []string `json:"aliases,omitempty"`
}
func addApiHandlers(pm *ProxyManager) {
// Add API endpoints for React to consume
apiGroup := pm.ginEngine.Group("/api")
// Protected with API key authentication
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
{
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics)
apiGroup.GET("/version", pm.apiGetVersion)
apiGroup.GET("/captures/:id", pm.apiGetCapture)
}
}
@@ -79,9 +84,22 @@ func (pm *ProxyManager) getModelStatus() []Model {
Description: pm.config.Models[modelID].Description,
State: state,
Unlisted: pm.config.Models[modelID].Unlisted,
Aliases: pm.config.Models[modelID].Aliases,
})
}
// Iterate over the peer models
if pm.peerProxy != nil {
for peerID, peer := range pm.peerProxy.ListPeers() {
for _, modelID := range peer.Models {
models = append(models, Model{
Id: modelID,
PeerID: peerID,
})
}
}
}
return models
}
@@ -91,6 +109,7 @@ const (
msgTypeModelStatus messageType = "modelStatus"
msgTypeLogData messageType = "logData"
msgTypeMetrics messageType = "metrics"
msgTypeInFlight messageType = "inflight"
)
type messageEnvelope struct {
@@ -150,6 +169,18 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
}
}
sendInFlight := func(total int) {
jsonData, err := json.Marshal(gin.H{"total": total})
if err == nil {
select {
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
case <-ctx.Done():
return
default:
}
}
}
/**
* Send updated models list
*/
@@ -177,11 +208,19 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
sendMetrics([]TokenMetrics{e.Metrics})
})()
/**
* Send in-flight request stats related to token stats "Waiting: N" count.
*/
defer event.On(func(e InFlightRequestsEvent) {
sendInFlight(e.Total)
})()
// send initial batch of data
sendLogData("proxy", pm.proxyLogger.GetHistory())
sendLogData("upstream", pm.upstreamLogger.GetHistory())
sendModels()
sendMetrics(pm.metricsMonitor.getMetrics())
sendInFlight(pm.inFlightCounter.Current())
for {
select {
@@ -236,3 +275,20 @@ func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
"build_date": pm.buildDate,
})
}
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
return
}
capture := pm.metricsMonitor.getCaptureByID(id)
if capture == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
return
}
c.JSON(http.StatusOK, capture)
}
+20 -13
View File
@@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// prevent nginx from buffering streamed logs
c.Header("X-Accel-Buffering", "no")
logMonitorId := c.Param("logMonitorID")
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
@@ -83,18 +83,25 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor
if logMonitorId == "" {
switch logMonitorId {
case "":
// maintain the default
logger = pm.muxLogger
} else if logMonitorId == "proxy" {
logger = pm.proxyLogger
} else if logMonitorId == "upstream" {
logger = pm.upstreamLogger
} else {
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
}
return pm.muxLogger, nil
case "proxy":
return pm.proxyLogger, nil
case "upstream":
return pm.upstreamLogger, nil
default:
// search for a models specific logger using findModelInPath
// to handle model names with slashes (e.g., "author/model")
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
for _, group := range pm.processGroups {
if process, found := group.GetMember(name); found {
return process.Logger(), nil
}
}
}
return logger, nil
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
}
}
+568 -15
View File
@@ -3,6 +3,7 @@ package proxy
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"math/rand"
@@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}
func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}
func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
@@ -223,17 +220,23 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
model2Config.Name = " " // empty whitespace only strings will get ignored
model2Config.Description = " "
config := config.Config{
cfg := config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": model1Config,
"model2": model2Config,
"model3": getTestSimpleResponderConfig("model3"),
},
Peers: map[string]config.PeerConfig{
"peer1": {
Proxy: "http://peer1:8080",
Models: []string{"peer-model-a", "peer-model-b"},
},
},
LogLevel: "error",
}
proxy := New(config)
proxy := New(cfg)
// Create a test request
req := httptest.NewRequest("GET", "/v1/models", nil)
@@ -258,14 +261,16 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
t.Fatalf("Failed to parse JSON response: %v", err)
}
// Check the number of models returned
assert.Len(t, response.Data, 3)
// Check the number of models returned (3 local + 2 peer models)
assert.Len(t, response.Data, 5)
// Check the details of each model
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
"model3": {},
"model1": {},
"model2": {},
"model3": {},
"peer-model-a": {},
"peer-model-b": {},
}
// make all models
@@ -296,6 +301,19 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
description, ok := model["description"].(string)
assert.True(t, ok, "description should be a string")
assert.Equal(t, "Model 1 description is used for testing", description)
} else if modelID == "peer-model-a" || modelID == "peer-model-b" {
// Peer models should have meta.llamaswap.peerID
meta, exists := model["meta"]
assert.True(t, exists, "peer model should have meta field")
metaMap, ok := meta.(map[string]interface{})
assert.True(t, ok, "meta should be a map")
llamaswap, exists := metaMap["llamaswap"]
assert.True(t, exists, "meta should have llamaswap field")
llamaswapMap, ok := llamaswap.(map[string]interface{})
assert.True(t, ok, "llamaswap should be a map")
peerID, exists := llamaswapMap["peerID"]
assert.True(t, exists, "llamaswap should have peerID field")
assert.Equal(t, "peer1", peerID)
} else {
_, exists := model["name"]
assert.False(t, exists, "unexpected name field for model: %s", modelID)
@@ -502,6 +520,10 @@ func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
}
func TestProxyManager_Shutdown(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
// make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
model1Config.Proxy = "http://localhost:10001/"
@@ -650,8 +672,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Define a helper struct to parse the JSON response.
type RunningResponse struct {
Running []struct {
Model string `json:"model"`
State string `json:"state"`
Model string `json:"model"`
State string `json:"state"`
Cmd string `json:"cmd"`
Proxy string `json:"proxy"`
TTL int `json:"ttl"`
Name string `json:"name"`
Description string `json:"description"`
} `json:"running"`
}
@@ -699,6 +726,11 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
// Is the model loaded?
assert.Equal(t, "ready", response.Running[0].State)
// Verify extended fields are present
assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated")
assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated")
assert.Equal(t, -1, response.Running[0].TTL, "ttl should default to -1 (use globalTTL)")
})
}
@@ -818,6 +850,43 @@ func TestProxyManager_UseModelName(t *testing.T) {
})
}
func TestProxyManager_AudioVoicesGETHandler(t *testing.T) {
conf := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(conf)
defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("successful GET with model query param", func(t *testing.T) {
req := httptest.NewRequest("GET", "/v1/audio/voices?model=model1", nil)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "voice1")
})
t.Run("missing model query param returns 400", func(t *testing.T) {
req := httptest.NewRequest("GET", "/v1/audio/voices", nil)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
})
t.Run("unknown model returns 400", func(t *testing.T) {
req := httptest.NewRequest("GET", "/v1/audio/voices?model=nonexistent", nil)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "could not find suitable handler")
})
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
@@ -944,7 +1013,9 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
func TestProxyManager_FiltersStripParams(t *testing.T) {
modelConfig := getTestSimpleResponderConfig("model1")
modelConfig.Filters = config.ModelFilters{
StripParams: "temperature, model, stream",
Filters: config.Filters{
StripParams: "temperature, model, stream",
},
}
config := config.AddDefaultGroupToConfig(config.Config{
@@ -975,6 +1046,61 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
// t.Logf("%v", response)
}
func TestProxyManager_FiltersSetParamsByID(t *testing.T) {
// no explicit aliases — setParamsByID keys are auto-registered as aliases
configStr := strings.Replace(`
logLevel: error
models:
model1:
cmd: 'SRPATH --port ${PORT} --silent --respond model1'
proxy: "http://127.0.0.1:${PORT}"
filters:
setParams:
reasoning_effort: medium
setParamsByID:
"${MODEL_ID}:high":
reasoning_effort: high
"${MODEL_ID}:low":
reasoning_effort: low
`, "SRPATH", simpleResponderPath, -1)
cfg, err := config.LoadConfigFromReader(strings.NewReader(configStr))
if !assert.NoError(t, err, "invalid test configuration") {
return
}
proxy := New(cfg)
defer proxy.StopProcesses(StopWaitForInflightRequest)
tests := []struct {
requestedModel string
wantEffort string
}{
// setParams applies, no setParamsByID match
{requestedModel: "model1", wantEffort: "medium"},
// setParamsByID overrides setParams
{requestedModel: "model1:high", wantEffort: "high"},
{requestedModel: "model1:low", wantEffort: "low"},
}
for _, tt := range tests {
t.Run(tt.requestedModel, func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":%q}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]interface{}
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
requestBody, _ := response["request_body"].(string)
gotEffort := gjson.Get(requestBody, "reasoning_effort").String()
assert.Equal(t, tt.wantEffort, gotEffort, "reasoning_effort mismatch for model %s", tt.requestedModel)
})
}
}
func TestProxyManager_HealthEndpoint(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
@@ -1078,7 +1204,8 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model1": getTestSimpleResponderConfig("model1"),
"author/model": getTestSimpleResponderConfig("author/model"),
},
LogLevel: "error",
})
@@ -1091,6 +1218,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
"/logs/stream",
"/logs/stream/proxy",
"/logs/stream/upstream",
"/logs/stream/author/model",
}
for _, endpoint := range endpoints {
@@ -1185,3 +1313,428 @@ func TestProxyManager_ApiGetVersion(t *testing.T) {
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
}
}
func TestProxyManager_APIKeyAuth(t *testing.T) {
testConfig := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"},
LogLevel: "error",
})
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
t.Run("valid key in x-api-key header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("valid key in Authorization Bearer header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("Authorization", "Bearer valid-key-2")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("both headers with matching keys", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
req.Header.Set("Authorization", "Bearer valid-key-1")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("invalid key returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "invalid-key")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "unauthorized")
})
t.Run("missing key returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
})
t.Run("valid key in Basic Auth header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
// Basic Auth: base64("anyuser:valid-key-1")
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "unauthorized")
})
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
})
}
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
// Config without RequiredAPIKeys - auth should be disabled
testConfig := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
t.Run("requests pass without API key when not configured", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
}
// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration
// in proxyInferenceHandler for issue #433
func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) {
t.Run("requests to peer models are proxied", func(t *testing.T) {
// Create a test server to act as the peer
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`))
}))
defer peerServer.Close()
// Create config with peers but no local model for "peer-model"
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
models:
- peer-model
models:
local-model:
cmd: %s -port ${PORT} -silent -respond local-model
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"peer-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "from-peer")
})
t.Run("local models take precedence over peer models", func(t *testing.T) {
// Create a test server to act as the peer - should NOT be called
peerCalled := false
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
peerCalled = true
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"response":"from-peer"}`))
}))
defer peerServer.Close()
// Create config where "shared-model" exists both locally and on peer
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
models:
- shared-model
models:
shared-model:
cmd: %s -port ${PORT} -silent -respond local-response
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"shared-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "local-response")
assert.False(t, peerCalled, "peer should not be called when local model exists")
})
t.Run("unknown model returns error", func(t *testing.T) {
// Create a test server to act as the peer
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer peerServer.Close()
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
models:
- peer-model
models:
local-model:
cmd: %s -port ${PORT} -silent -respond local-model
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"unknown-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
})
t.Run("peer API key is injected into request", func(t *testing.T) {
var receivedAuthHeader string
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"response":"ok"}`))
}))
defer peerServer.Close()
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
apiKey: secret-peer-key
models:
- peer-model
models:
local-model:
cmd: %s -port ${PORT} -silent -respond local-model
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"peer-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader)
})
t.Run("no peers configured - unknown model returns error", func(t *testing.T) {
testConfig := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"local-model": getTestSimpleResponderConfig("local-model"),
},
LogLevel: "error",
})
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
// peerProxy exists but has no peer models configured
assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model"))
reqBody := `{"model":"unknown-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
})
t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) {
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte("data: test\n\n"))
}))
defer peerServer.Close()
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
models:
- peer-model
models:
local-model:
cmd: %s -port ${PORT} -silent -respond local-model
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"peer-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
})
}
func TestProxyManager_SdApiTxt2ImgRouting(t *testing.T) {
conf := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"sd-model": getTestSimpleResponderConfig("sd-model"),
},
LogLevel: "error",
})
proxy := New(conf)
defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("successful txt2img with model", func(t *testing.T) {
reqBody := `{"model":"sd-model","prompt":"a cat"}`
req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "sd-model")
})
t.Run("successful img2img with model", func(t *testing.T) {
reqBody := `{"model":"sd-model","prompt":"a cat","init_images":[]}`
req := httptest.NewRequest("POST", "/sdapi/v1/img2img", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "sd-model")
})
t.Run("missing model returns 400", func(t *testing.T) {
reqBody := `{"prompt":"a cat"}`
req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "missing or invalid 'model' key")
})
}
func TestProxyManager_SdApiGetLoras(t *testing.T) {
conf := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"sd-model": getTestSimpleResponderConfig("sd-model"),
},
LogLevel: "error",
})
proxy := New(conf)
defer proxy.StopProcesses(StopWaitForInflightRequest)
t.Run("successful GET loras with model query param", func(t *testing.T) {
req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=sd-model", nil)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("missing model query param returns 400", func(t *testing.T) {
req := httptest.NewRequest("GET", "/sdapi/v1/loras", nil)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
})
t.Run("unknown model returns 400", func(t *testing.T) {
req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=nonexistent", nil)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "could not find suitable handler")
})
}
+81
View File
@@ -0,0 +1,81 @@
package proxy
import (
"net/http"
"strings"
)
// selectEncoding chooses the best encoding based on Accept-Encoding header
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
func selectEncoding(acceptEncoding string) (encoding, ext string) {
if acceptEncoding == "" {
return "", ""
}
for _, part := range strings.Split(acceptEncoding, ",") {
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
if enc == "br" {
return "br", ".br"
}
}
for _, part := range strings.Split(acceptEncoding, ",") {
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
if enc == "gzip" {
return "gzip", ".gz"
}
}
return "", ""
}
// ServeCompressedFile serves a file with compression support.
// It checks for pre-compressed versions and serves them with proper headers.
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
// Try to serve compressed version if client supports it
if encoding != "" {
if cf, err := fs.Open(name + ext); err == nil {
defer cf.Close()
// Verify it's a regular file (not a directory)
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
// Set the content encoding header
w.Header().Set("Content-Encoding", encoding)
w.Header().Add("Vary", "Accept-Encoding")
// Get original file info for content type detection
origFile, err := fs.Open(name)
if err == nil {
origFile.Close()
}
// Serve the compressed file
http.ServeContent(w, r, name, stat.ModTime(), cf)
return
}
}
}
// Fall back to serving the uncompressed file
file, err := fs.Open(name)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if stat.IsDir() {
http.Error(w, "is a directory", http.StatusForbidden)
return
}
http.ServeContent(w, r, name, stat.ModTime(), file)
}
+283
View File
@@ -0,0 +1,283 @@
package proxy
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"testing/fstest"
"time"
)
func TestServeCompressedFile_Brotli(t *testing.T) {
// Create test content
content := []byte("This is test content that should be compressed with brotli")
brContent := []byte("fake-brotli-compressed-data")
// Create a test filesystem
mapFS := fstest.MapFS{
"test.js": {Data: content, ModTime: time.Now()},
"test.js.br": {Data: brContent, ModTime: time.Now()},
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
req.Header.Set("Accept-Encoding", "br, gzip")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "test.js")
resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check that brotli is used (preferred over gzip)
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
}
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
}
if !bytes.Equal(body, brContent) {
t.Errorf("Expected brotli content, got %s", string(body))
}
}
func TestServeCompressedFile_Gzip(t *testing.T) {
// Create test content
content := []byte("This is test content that should be compressed with gzip")
gzContent := []byte("fake-gzip-compressed-data")
// Create a test filesystem without brotli
mapFS := fstest.MapFS{
"test.js": {Data: content, ModTime: time.Now()},
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
req.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "test.js")
resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
}
if !bytes.Equal(body, gzContent) {
t.Errorf("Expected gzip content, got %s", string(body))
}
}
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
// Create test content
content := []byte("This is uncompressed test content")
// Create a test filesystem without compressed versions
mapFS := fstest.MapFS{
"test.js": {Data: content, ModTime: time.Now()},
}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
req.Header.Set("Accept-Encoding", "br, gzip")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "test.js")
resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Should not have Content-Encoding header since we're serving uncompressed
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
}
if !bytes.Equal(body, content) {
t.Errorf("Expected original content, got %s", string(body))
}
}
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
// Create test content
content := []byte("This is test content")
// Create a test filesystem with compressed versions
mapFS := fstest.MapFS{
"test.js": {Data: content, ModTime: time.Now()},
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
// No Accept-Encoding header
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "test.js")
resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Should serve uncompressed content
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
}
if !bytes.Equal(body, content) {
t.Errorf("Expected original content, got %s", string(body))
}
}
func TestServeCompressedFile_NotFound(t *testing.T) {
mapFS := fstest.MapFS{}
fs := http.FS(mapFS)
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, "nonexistent.js")
resp := w.Result()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected status 404, got %d", resp.StatusCode)
}
}
func TestSelectEncoding(t *testing.T) {
tests := []struct {
acceptEncoding string
wantEncoding string
wantExt string
}{
{"br, gzip", "br", ".br"},
{"gzip, deflate", "gzip", ".gz"},
{"gzip", "gzip", ".gz"},
{"br", "br", ".br"},
{"", "", ""},
{"deflate", "", ""},
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
{"browser", "", ""},
{"compress, deflate", "", ""},
}
for _, tt := range tests {
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
}
}
}
// Test with actual pre-compressed files from ui_dist
func TestServeCompressedFile_RealFiles(t *testing.T) {
// Check if ui_dist exists
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
t.Skip("ui_dist not found, skipping real file test")
}
// Find a .js or .css file that has compressed versions
entries, err := os.ReadDir("./ui_dist/assets")
if err != nil {
t.Skipf("Could not read ui_dist/assets: %v", err)
}
var testFile string
for _, entry := range entries {
name := entry.Name()
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
// Check if compressed versions exist
base := strings.TrimSuffix(name, ".js")
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
testFile = "assets/" + name
break
}
}
}
if testFile == "" {
t.Skip("No suitable test file found with compressed versions")
}
fs := http.FS(os.DirFS("./ui_dist"))
// Test brotli
t.Run("brotli", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
req.Header.Set("Accept-Encoding", "br")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, testFile)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
}
})
// Test gzip
t.Run("gzip", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
req.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
ServeCompressedFile(fs, w, req, testFile)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
}
// Verify it's valid gzip
reader, err := gzip.NewReader(resp.Body)
if err != nil {
t.Errorf("Expected valid gzip content: %v", err)
return
}
defer reader.Close()
// Just read to verify it's valid
_, err = io.Copy(io.Discard, reader)
if err != nil {
t.Errorf("Failed to decompress gzip: %v", err)
}
})
}
+2
View File
@@ -0,0 +1,2 @@
node_modules
.vite
+1
View File
@@ -0,0 +1 @@
legacy-peer-deps=true
+3 -3
View File
@@ -10,8 +10,8 @@
<link rel="manifest" href="/site.webmanifest" />
<title>llama-swap</title>
</head>
<body >
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
<body>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>
+3916
View File
File diff suppressed because it is too large Load Diff
+42
View File
@@ -0,0 +1,42 @@
{
"name": "ui-svelte",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"start": "vite",
"build": "vite build --emptyOutDir",
"preview": "vite preview",
"check": "svelte-check --tsconfig ./tsconfig.json",
"test": "vitest run",
"test:watch": "vitest"
},
"devDependencies": {
"@sveltejs/vite-plugin-svelte": "^7.0.0",
"@tailwindcss/vite": "^4.1.8",
"@tsconfig/svelte": "^5.0.4",
"@types/hast": "^3.0.4",
"@types/node": "^25.1.0",
"svelte": "^5.46.4",
"svelte-check": "^4.1.4",
"tailwindcss": "^4.1.8",
"typescript": "~5.8.3",
"vite": "^8.0.0",
"vite-plugin-compression2": "^2.5.1",
"vitest": "^4.1.0"
},
"dependencies": {
"highlight.js": "^11.11.1",
"katex": "^0.16.28",
"lucide-svelte": "^0.563.0",
"rehype-katex": "^7.0.1",
"rehype-stringify": "^10.0.1",
"remark-gfm": "^4.0.1",
"remark-math": "^6.0.0",
"remark-parse": "^11.0.0",
"remark-rehype": "^11.1.2",
"svelte-spa-router": "^4.0.1",
"unified": "^11.0.5",
"unist-util-visit": "^5.1.0"
}
}

Before

Width:  |  Height:  |  Size: 5.9 KiB

After

Width:  |  Height:  |  Size: 5.9 KiB

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

Before

Width:  |  Height:  |  Size: 38 KiB

After

Width:  |  Height:  |  Size: 38 KiB

Before

Width:  |  Height:  |  Size: 6.5 KiB

After

Width:  |  Height:  |  Size: 6.5 KiB

Before

Width:  |  Height:  |  Size: 28 KiB

After

Width:  |  Height:  |  Size: 28 KiB

+58
View File
@@ -0,0 +1,58 @@
<script lang="ts">
import { onMount } from "svelte";
import Router from "svelte-spa-router";
import Header from "./components/Header.svelte";
import LogViewer from "./routes/LogViewer.svelte";
import Models from "./routes/Models.svelte";
import Activity from "./routes/Activity.svelte";
import Playground from "./routes/Playground.svelte";
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
import { enableAPIEvents } from "./stores/api";
import { initScreenWidth, isDarkMode, appTitle, connectionState } from "./stores/theme";
import { currentRoute } from "./stores/route";
const routes = {
"/": PlaygroundStub,
"/models": Models,
"/logs": LogViewer,
"/activity": Activity,
"*": PlaygroundStub,
};
function handleRouteLoaded(event: { detail: { route: string | RegExp } }) {
const route = event.detail.route;
currentRoute.set(typeof route === "string" ? route : "/");
}
$effect(() => {
document.documentElement.setAttribute("data-theme", $isDarkMode ? "dark" : "light");
});
$effect(() => {
const icon = $connectionState === "connecting" ? "\u{1F7E1}" : $connectionState === "connected" ? "\u{1F7E2}" : "\u{1F534}";
document.title = `${icon} ${$appTitle}`;
});
onMount(() => {
const cleanupScreenWidth = initScreenWidth();
enableAPIEvents(true);
return () => {
cleanupScreenWidth();
enableAPIEvents(false);
};
});
</script>
<div class="flex flex-col h-screen">
<Header />
<main class="flex-1 overflow-auto p-4">
<div class="h-full" class:hidden={$currentRoute !== "/"}>
<Playground />
</div>
<div class="h-full" class:hidden={$currentRoute === "/"}>
<Router {routes} on:routeLoaded={handleRouteLoaded} />
</div>
</main>
</div>

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 12 KiB

Before

Width:  |  Height:  |  Size: 4.0 KiB

After

Width:  |  Height:  |  Size: 4.0 KiB

@@ -0,0 +1,452 @@
<script lang="ts">
import type { ReqRespCapture } from "../lib/types";
interface Props {
capture: ReqRespCapture | null;
open: boolean;
onclose: () => void;
}
let { capture, open, onclose }: Props = $props();
let dialogEl: HTMLDialogElement | undefined = $state();
type BodyTab = "raw" | "pretty" | "chat";
let reqBodyTab: BodyTab = $state("pretty");
let respBodyTab: BodyTab = $state("pretty");
let copiedReq = $state(false);
let copiedResp = $state(false);
$effect(() => {
if (open && dialogEl) {
dialogEl.showModal();
} else if (!open && dialogEl) {
dialogEl.close();
}
});
// Reset tabs when capture changes
$effect(() => {
if (capture) {
const reqCt = getContentType(capture.req_headers);
const respCt = getContentType(capture.resp_headers);
reqBodyTab = reqCt.includes("json") ? "pretty" : "raw";
respBodyTab = respCt.includes("text/event-stream")
? "chat"
: respCt.includes("json")
? "pretty"
: "raw";
}
});
function handleDialogClose() {
onclose();
}
function decodeBody(body: string | null | undefined): string {
if (!body) return "";
try {
const binary = atob(body);
const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
return new TextDecoder().decode(bytes);
} catch {
return body;
}
}
function formatJson(str: string): string {
try {
const parsed = JSON.parse(str);
return JSON.stringify(parsed, null, 2);
} catch {
return str;
}
}
function getContentType(
headers: Record<string, string> | null | undefined,
): string {
if (!headers) return "";
const ct = headers["Content-Type"] || headers["content-type"] || "";
return ct.toLowerCase();
}
function isImageContentType(contentType: string): boolean {
return contentType.startsWith("image/");
}
function isTextContentType(contentType: string): boolean {
return (
contentType.startsWith("text/") ||
contentType.includes("application/json") ||
contentType.includes("application/xml") ||
contentType.includes("application/javascript")
);
}
function getImageDataUrl(body: string, contentType: string): string {
const mimeType = contentType.split(";")[0].trim();
return `data:${mimeType};base64,${body}`;
}
interface SSEChat {
reasoning: string;
content: string;
}
function parseSSEChat(text: string): SSEChat {
const result: SSEChat = { reasoning: "", content: "" };
for (const line of text.split("\n")) {
const trimmed = line.trim();
if (!trimmed || !trimmed.startsWith("data: ")) continue;
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
try {
const parsed = JSON.parse(data);
const delta = parsed.choices?.[0]?.delta;
if (delta?.content) result.content += delta.content;
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
} catch {
// skip unparseable lines
}
}
return result;
}
async function copyToClipboard(text: string, type: "req" | "resp") {
try {
await navigator.clipboard.writeText(text);
if (type === "req") {
copiedReq = true;
setTimeout(() => (copiedReq = false), 1500);
} else {
copiedResp = true;
setTimeout(() => (copiedResp = false), 1500);
}
} catch {
// ignore
}
}
function getCopyText(): string {
if (respBodyTab === "chat") {
let text = "";
if (sseChat.reasoning) text += sseChat.reasoning + "\n\n";
text += sseChat.content;
return text;
}
return displayedResponseBody;
}
// Request body derivations
let requestContentType = $derived(
capture ? getContentType(capture.req_headers) : "",
);
let isRequestJson = $derived(requestContentType.includes("json"));
let requestBodyRaw = $derived.by(() => {
if (!capture) return "";
return decodeBody(capture.req_body);
});
let requestBodyPretty = $derived.by(() => {
if (!isRequestJson) return requestBodyRaw;
return formatJson(requestBodyRaw);
});
let displayedRequestBody = $derived(
reqBodyTab === "pretty" ? requestBodyPretty : requestBodyRaw,
);
// Response body derivations
let responseContentType = $derived(
capture ? getContentType(capture.resp_headers) : "",
);
let isResponseImage = $derived(isImageContentType(responseContentType));
let isResponseText = $derived(isTextContentType(responseContentType));
let isResponseJson = $derived(responseContentType.includes("json"));
let isSSE = $derived(responseContentType.includes("text/event-stream"));
let responseBodyRaw = $derived.by(() => {
if (!capture) return "";
return decodeBody(capture.resp_body);
});
let responseBodyPretty = $derived.by(() => {
if (!isResponseJson) return responseBodyRaw;
return formatJson(responseBodyRaw);
});
let sseChat = $derived.by(() => {
if (!isSSE || !responseBodyRaw)
return { reasoning: "", content: "" } as SSEChat;
return parseSSEChat(responseBodyRaw);
});
let displayedResponseBody = $derived.by(() => {
if (respBodyTab === "pretty") return responseBodyPretty;
return responseBodyRaw;
});
</script>
<dialog
bind:this={dialogEl}
onclose={handleDialogClose}
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-4xl w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
>
{#if capture}
<div class="flex flex-col max-h-[90vh]">
<div
class="flex justify-between items-center p-4 border-b border-card-border"
>
<h2 class="text-xl font-bold pb-0">Capture #{capture.id + 1}{#if capture.req_path} <span class="text-base font-mono font-normal text-txtsecondary">{capture.req_path}</span>{/if}</h2>
<button
onclick={() => dialogEl?.close()}
class="text-txtsecondary hover:text-txtmain text-2xl leading-none"
>
&times;
</button>
</div>
<div class="overflow-y-auto flex-1 p-4 space-y-4">
<!-- Request Headers -->
<details class="group" open>
<summary
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
>
Request Headers
</summary>
<div
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
>
<table class="w-full text-sm">
<tbody>
{#each Object.entries(capture.req_headers || {}) as [key, value]}
<tr class="border-b border-card-border-inner last:border-0">
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
>{key}</td
>
<td class="px-3 py-1 font-mono break-all">{value}</td>
</tr>
{/each}
</tbody>
</table>
</div>
</details>
<!-- Request Body -->
<details class="group" open>
<summary
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
>
Request Body
</summary>
{#if requestBodyRaw}
<div class="mt-2 flex items-center justify-between">
<div class="flex gap-1">
{#if isRequestJson}
<button
class="tab-btn"
class:tab-btn-active={reqBodyTab === "pretty"}
onclick={() => (reqBodyTab = "pretty")}>Pretty</button
>
<button
class="tab-btn"
class:tab-btn-active={reqBodyTab === "raw"}
onclick={() => (reqBodyTab = "raw")}>Raw</button
>
{/if}
</div>
<button
class="tab-btn"
onclick={() =>
copyToClipboard(displayedRequestBody, "req")}
>
{#if copiedReq}
Copied!
{:else}
Copy
{/if}
</button>
</div>
<div
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
>
<pre
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedRequestBody}</pre>
</div>
{:else}
<div
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
>
<pre class="p-3 text-sm font-mono whitespace-pre-wrap break-all"
>(empty)</pre
>
</div>
{/if}
</details>
<!-- Response Headers -->
<details class="group" open>
<summary
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
>
Response Headers
</summary>
<div
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
>
<table class="w-full text-sm">
<tbody>
{#each Object.entries(capture.resp_headers || {}) as [key, value]}
<tr class="border-b border-card-border-inner last:border-0">
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
>{key}</td
>
<td class="px-3 py-1 font-mono break-all">{value}</td>
</tr>
{/each}
</tbody>
</table>
</div>
</details>
<!-- Response Body -->
<details class="group" open>
<summary
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
>
Response Body
</summary>
{#if isResponseImage && capture.resp_body}
<div
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
>
<div class="p-3 flex justify-center">
<img
src={getImageDataUrl(capture.resp_body, responseContentType)}
alt="Response"
class="max-w-full h-auto"
/>
</div>
</div>
{:else if isSSE || isResponseText}
<div class="mt-2 flex items-center justify-between">
<div class="flex gap-1">
{#if isSSE}
<button
class="tab-btn"
class:tab-btn-active={respBodyTab === "chat"}
onclick={() => (respBodyTab = "chat")}>Chat</button
>
{/if}
{#if isResponseJson}
<button
class="tab-btn"
class:tab-btn-active={respBodyTab === "pretty"}
onclick={() => (respBodyTab = "pretty")}>Pretty</button
>
{/if}
{#if isSSE || isResponseJson}
<button
class="tab-btn"
class:tab-btn-active={respBodyTab === "raw"}
onclick={() => (respBodyTab = "raw")}>Raw</button
>
{/if}
</div>
<button
class="tab-btn"
onclick={() => copyToClipboard(getCopyText(), "resp")}
>
{#if copiedResp}
Copied!
{:else}
Copy
{/if}
</button>
</div>
<div
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
>
{#if respBodyTab === "chat"}
<div class="p-3 text-sm space-y-3">
{#if sseChat.reasoning}
<div>
<div
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
>
Reasoning
</div>
<pre
class="font-mono whitespace-pre-wrap break-all text-txtsecondary">{sseChat.reasoning}</pre>
</div>
{/if}
{#if sseChat.content}
<div>
{#if sseChat.reasoning}
<div
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
>
Response
</div>
{/if}
<pre
class="font-mono whitespace-pre-wrap break-all">{sseChat.content}</pre>
</div>
{/if}
{#if !sseChat.reasoning && !sseChat.content}
<pre class="font-mono">(empty)</pre>
{/if}
</div>
{:else}
<pre
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedResponseBody || "(empty)"}</pre>
{/if}
</div>
{:else if responseBodyRaw}
<div
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
>
<div class="p-3 text-sm text-txtsecondary italic">
(binary data - {responseContentType || "unknown content type"})
</div>
</div>
{:else}
<div
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
>
<pre class="p-3 text-sm font-mono">(empty)</pre>
</div>
{/if}
</details>
</div>
<div class="p-4 border-t border-card-border flex justify-end">
<button onclick={() => dialogEl?.close()} class="btn"> Close </button>
</div>
</div>
{/if}
</dialog>
<style>
.tab-btn {
padding: 2px 10px;
font-size: 0.75rem;
border-radius: 4px;
color: var(--color-txtsecondary);
cursor: pointer;
border: 1px solid transparent;
background: transparent;
transition: all 0.15s;
}
.tab-btn:hover {
color: var(--color-txtmain);
background: var(--color-secondary);
}
.tab-btn-active {
color: var(--color-primary);
background: color-mix(in srgb, var(--color-primary) 12%, transparent);
border-color: color-mix(in srgb, var(--color-primary) 25%, transparent);
}
</style>
@@ -0,0 +1,24 @@
<script lang="ts">
import { connectionState } from "../stores/theme";
import { versionInfo } from "../stores/api";
let eventStatusColor = $derived.by(() => {
switch ($connectionState) {
case "connected":
return "bg-emerald-500";
case "connecting":
return "bg-amber-500";
case "disconnected":
default:
return "bg-red-500";
}
});
let tooltipText = $derived(
`Event Stream: ${$connectionState ?? "unknown"}\nAPI Version: ${$versionInfo?.version ?? "unknown"}\nCommit Hash: ${$versionInfo?.commit?.substring(0, 7) ?? "unknown"}\nBuild Date: ${$versionInfo?.build_date ?? "unknown"}`
);
</script>
<div class="flex items-center" title={tooltipText}>
<span class="inline-block w-3 h-3 rounded-full {eventStatusColor} mr-2"></span>
</div>
+120
View File
@@ -0,0 +1,120 @@
<script lang="ts">
import { link } from "svelte-spa-router";
import { screenWidth, toggleTheme, isDarkMode, appTitle, isNarrow } from "../stores/theme";
import { currentRoute } from "../stores/route";
import { playgroundActivity } from "../stores/playgroundActivity";
import ConnectionStatus from "./ConnectionStatus.svelte";
function handleTitleChange(newTitle: string): void {
const sanitized = newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap";
appTitle.set(sanitized);
}
function handleKeyDown(e: KeyboardEvent): void {
if (e.key === "Enter") {
e.preventDefault();
const target = e.currentTarget as HTMLElement;
handleTitleChange(target.textContent || "(set title)");
target.blur();
}
}
function handleBlur(e: FocusEvent): void {
const target = e.currentTarget as HTMLElement;
handleTitleChange(target.textContent || "(set title)");
}
function isActive(path: string, current: string): boolean {
return path === "/" ? current === "/" : current.startsWith(path);
}
</script>
<header
class="flex items-center justify-between bg-surface border-b border-border px-4 {$isNarrow
? 'py-1 h-[60px]'
: 'p-2 h-[75px]'}"
>
{#if $screenWidth !== "xs" && $screenWidth !== "sm"}
<h1
contenteditable="true"
class="p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
onblur={handleBlur}
onkeydown={handleKeyDown}
>
{$appTitle}
</h1>
{/if}
<menu class="flex items-center gap-4 overflow-x-auto">
<a
href="/"
use:link
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
>
Playground
</a>
<a
href="/models"
use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/models", $currentRoute)}
>
Models
</a>
<a
href="/activity"
use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/activity", $currentRoute)}
>
Activity
</a>
<a
href="/logs"
use:link
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:font-semibold={isActive("/logs", $currentRoute)}
>
Logs
</a>
<button onclick={toggleTheme} title="Toggle theme">
{#if $isDarkMode}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path
fill-rule="evenodd"
d="M9.528 1.718a.75.75 0 0 1 .162.819A8.97 8.97 0 0 0 9 6a9 9 0 0 0 9 9 8.97 8.97 0 0 0 3.463-.69.75.75 0 0 1 .981.98 10.503 10.503 0 0 1-9.694 6.46c-5.799 0-10.5-4.7-10.5-10.5 0-4.368 2.667-8.112 6.46-9.694a.75.75 0 0 1 .818.162Z"
clip-rule="evenodd"
/>
</svg>
{:else}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path
d="M12 2.25a.75.75 0 0 1 .75.75v2.25a.75.75 0 0 1-1.5 0V3a.75.75 0 0 1 .75-.75ZM7.5 12a4.5 4.5 0 1 1 9 0 4.5 4.5 0 0 1-9 0ZM18.894 6.166a.75.75 0 0 0-1.06-1.06l-1.591 1.59a.75.75 0 1 0 1.06 1.061l1.591-1.59ZM21.75 12a.75.75 0 0 1-.75.75h-2.25a.75.75 0 0 1 0-1.5H21a.75.75 0 0 1 .75.75ZM17.834 18.894a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 1 0-1.061 1.06l1.591 1.591ZM12 18a.75.75 0 0 1 .75.75V21a.75.75 0 0 1-1.5 0v-2.25A.75.75 0 0 1 12 18ZM7.758 17.303a.75.75 0 0 0-1.061-1.06l-1.591 1.59a.75.75 0 0 0 1.06 1.061l1.591-1.59ZM6 12a.75.75 0 0 1-.75.75H3a.75.75 0 0 1 0-1.5h2.25A.75.75 0 0 1 6 12ZM6.697 7.757a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 0 0-1.061 1.06l1.59 1.591Z"
/>
</svg>
{/if}
</button>
<ConnectionStatus />
</menu>
</header>
<style>
.activity-link {
background: linear-gradient(90deg, #6366f1, #8b5cf6, #a855f7, #8b5cf6, #6366f1);
background-size: 200% 100%;
-webkit-background-clip: text;
background-clip: text;
-webkit-text-fill-color: transparent;
animation: gradient-shift 2s linear infinite;
}
@keyframes gradient-shift {
0% {
background-position: 0% 50%;
}
100% {
background-position: 200% 50%;
}
}
</style>
+139
View File
@@ -0,0 +1,139 @@
<script lang="ts">
import { persistentStore } from "../stores/persistent";
interface Props {
id: string;
title: string;
logData: string;
}
let { id, title, logData }: Props = $props();
let filterRegex = $state("");
// Create persistent stores for this panel (id is intentionally captured at init time)
// svelte-ignore state_referenced_locally
const fontSizeStore = persistentStore<"xxs" | "xs" | "small" | "normal">(`logPanel-${id}-fontSize`, "normal");
// svelte-ignore state_referenced_locally
const wrapTextStore = persistentStore<boolean>(`logPanel-${id}-wrapText`, false);
// svelte-ignore state_referenced_locally
const showFilterStore = persistentStore<boolean>(`logPanel-${id}-showFilter`, false);
let textWrapClass = $derived($wrapTextStore ? "whitespace-pre-wrap" : "whitespace-pre");
function toggleFontSize(): void {
fontSizeStore.update((prev) => {
switch (prev) {
case "xxs": return "xs";
case "xs": return "small";
case "small": return "normal";
case "normal": return "xxs";
}
});
}
function toggleWrapText(): void {
wrapTextStore.update((prev) => !prev);
}
function toggleFilter(): void {
if ($showFilterStore) {
showFilterStore.set(false);
filterRegex = "";
} else {
showFilterStore.set(true);
}
}
let fontSizeClass = $derived.by(() => {
switch ($fontSizeStore) {
case "xxs": return "text-[0.5rem]";
case "xs": return "text-[0.75rem]";
case "small": return "text-[0.875rem]";
case "normal": return "text-base";
}
});
let filteredLogs = $derived.by(() => {
if (!filterRegex) return logData;
try {
const regex = new RegExp(filterRegex, "i");
return logData.split("\n").filter((line) => regex.test(line)).join("\n");
} catch {
return logData;
}
});
let preElement: HTMLPreElement;
let userScrolledUp = $state(false);
function handleScroll() {
if (!preElement) return;
const { scrollTop, scrollHeight, clientHeight } = preElement;
userScrolledUp = scrollHeight - scrollTop - clientHeight > 40;
}
// Auto scroll to bottom when logs change, unless user has scrolled up
$effect(() => {
if (preElement && filteredLogs && !userScrolledUp) {
preElement.scrollTop = preElement.scrollHeight;
}
});
</script>
<div class="rounded-lg overflow-hidden flex flex-col bg-gray-950/5 dark:bg-white/10 h-full w-full p-1">
<div class="p-4">
<div class="flex items-center justify-between">
<h3 class="m-0 text-lg p-0">{title}</h3>
<div class="flex gap-2 items-center">
<button class="btn border-0" onclick={toggleFontSize} title="Change font size">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
<path d="M2 4v3h5v12h3V7h5V4H2zm19 5h-9v3h3v7h3v-7h3V9z"/>
</svg>
</button>
<button class="btn border-0" onclick={toggleWrapText} title="Toggle text wrap">
{#if $wrapTextStore}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
</svg>
{:else}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h10.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
</svg>
{/if}
</button>
<button class="btn border-0" onclick={toggleFilter} title="Toggle filter">
{#if $showFilterStore}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
<path fill-rule="evenodd" d="M10.5 3.75a6.75 6.75 0 1 0 0 13.5 6.75 6.75 0 0 0 0-13.5ZM2.25 10.5a8.25 8.25 0 1 1 14.59 5.28l4.69 4.69a.75.75 0 1 1-1.06 1.06l-4.69-4.69A8.25 8.25 0 0 1 2.25 10.5Z" clip-rule="evenodd" />
</svg>
{:else}
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-4 h-4">
<path stroke-linecap="round" stroke-linejoin="round" d="m21 21-5.197-5.197m0 0A7.5 7.5 0 1 0 5.196 5.196a7.5 7.5 0 0 0 10.607 10.607Z" />
</svg>
{/if}
</button>
</div>
</div>
{#if $showFilterStore}
<div class="mt-2 flex gap-2 items-center w-full">
<input
type="text"
class="w-full text-sm border border-gray-950/10 dark:border-white/5 p-2 rounded outline-none"
placeholder="Filter logs (regex)..."
bind:value={filterRegex}
/>
<button class="pl-2" onclick={() => (filterRegex = "")} aria-label="Clear filter">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm-1.72 6.97a.75.75 0 1 0-1.06 1.06L10.94 12l-1.72 1.72a.75.75 0 1 0 1.06 1.06L12 13.06l1.72 1.72a.75.75 0 1 0 1.06-1.06L13.06 12l1.72-1.72a.75.75 0 1 0-1.06-1.06L12 10.94l-1.72-1.72Z" clip-rule="evenodd" />
</svg>
</button>
</div>
{/if}
</div>
<div class="rounded-lg bg-background font-mono text-sm flex-1 overflow-hidden">
<pre bind:this={preElement} onscroll={handleScroll} class="{textWrapClass} {fontSizeClass} h-full overflow-auto p-4">{filteredLogs}</pre>
</div>
</div>
+211
View File
@@ -0,0 +1,211 @@
<script lang="ts">
import { models, loadModel, unloadAllModels, unloadSingleModel } from "../stores/api";
import { isNarrow } from "../stores/theme";
import { persistentStore } from "../stores/persistent";
import type { Model } from "../lib/types";
let isUnloading = $state(false);
let menuOpen = $state(false);
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
let filteredModels = $derived.by(() => {
const filtered = $models.filter((model) => $showUnlistedStore || !model.unlisted);
const peerModels = filtered.filter((m) => m.peerID);
// Group peer models by peerID
const grouped = peerModels.reduce(
(acc, model) => {
const peerId = model.peerID || "unknown";
if (!acc[peerId]) acc[peerId] = [];
acc[peerId].push(model);
return acc;
},
{} as Record<string, Model[]>
);
return {
regularModels: filtered.filter((m) => !m.peerID),
peerModelsByPeerId: grouped,
};
});
async function handleUnloadAllModels(): Promise<void> {
isUnloading = true;
try {
await unloadAllModels();
} catch (e) {
console.error(e);
} finally {
setTimeout(() => (isUnloading = false), 1000);
}
}
function toggleIdorName(): void {
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
}
function toggleShowUnlisted(): void {
showUnlistedStore.update((prev) => !prev);
}
function getModelDisplay(model: Model): string {
return $showIdorNameStore === "id" ? model.id : (model.name || model.id);
}
</script>
<div class="card h-full flex flex-col">
<div class="shrink-0">
<div class="flex justify-between items-baseline">
<h2 class={$isNarrow ? "text-xl" : ""}>Models</h2>
{#if $isNarrow}
<div class="relative">
<button class="btn text-base flex items-center gap-2 py-1" onclick={() => (menuOpen = !menuOpen)} aria-label="Toggle menu">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
</svg>
</button>
{#if menuOpen}
<div class="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
<button
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onclick={() => { toggleIdorName(); menuOpen = false; }}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
</svg>
{$showIdorNameStore === "id" ? "Show Name" : "Show ID"}
</button>
<button
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onclick={() => { toggleShowUnlisted(); menuOpen = false; }}
>
{#if $showUnlistedStore}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
</svg>
{:else}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
</svg>
{/if}
{$showUnlistedStore ? "Hide Unlisted" : "Show Unlisted"}
</button>
<button
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
onclick={() => { handleUnloadAllModels(); menuOpen = false; }}
disabled={isUnloading}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
</svg>
{isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
{/if}
</div>
{/if}
</div>
{#if !$isNarrow}
<div class="flex justify-between">
<div class="flex gap-2">
<button class="btn text-base flex items-center gap-2" onclick={toggleIdorName} style="line-height: 1.2">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
</svg>
{$showIdorNameStore === "id" ? "ID" : "Name"}
</button>
<button class="btn text-base flex items-center gap-2" onclick={toggleShowUnlisted} style="line-height: 1.2">
{#if $showUnlistedStore}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
</svg>
{:else}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
</svg>
{/if}
unlisted
</button>
</div>
<button class="btn text-base flex items-center gap-2" onclick={handleUnloadAllModels} disabled={isUnloading}>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
</svg>
{isUnloading ? "Unloading..." : "Unload All"}
</button>
</div>
{/if}
</div>
<div class="flex-1 overflow-y-auto">
<table class="w-full">
<thead class="sticky top-0 bg-card z-10">
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
<th>{$showIdorNameStore === "id" ? "Model ID" : "Name"}</th>
<th></th>
<th>State</th>
</tr>
</thead>
<tbody>
{#each filteredModels.regularModels as model (model.id)}
<tr class="border-b hover:bg-secondary-hover border-gray-200">
<td class={model.unlisted ? "text-txtsecondary" : ""}>
<a href="/upstream/{model.id}/" class="font-semibold" target="_blank">
{getModelDisplay(model)}
</a>
{#if model.description}
<p class={model.unlisted ? "text-opacity-70" : ""}><em>{model.description}</em></p>
{/if}
{#if model.aliases && model.aliases.length > 0}
<p class="text-xs text-txtsecondary">Aliases: {model.aliases.join(", ")}</p>
{/if}
</td>
<td class="w-12">
{#if model.state === "stopped"}
<button class="btn btn--sm" onclick={() => loadModel(model.id)}>Load</button>
{:else}
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
{/if}
</td>
<td class="w-20">
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
</td>
</tr>
{/each}
</tbody>
</table>
{#if Object.keys(filteredModels.peerModelsByPeerId).length > 0}
<h3 class="mt-8 mb-2">Peer Models</h3>
{#each Object.entries(filteredModels.peerModelsByPeerId).sort(([a], [b]) => a.localeCompare(b)) as [peerId, peerModels] (peerId)}
<div class="mb-4">
<table class="w-full">
<thead class="sticky top-0 bg-card z-10">
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
<th class="font-semibold">{peerId}</th>
</tr>
</thead>
<tbody>
{#each peerModels as model (model.id)}
<tr class="border-b hover:bg-secondary-hover border-gray-200">
<td class="pl-8 {model.unlisted ? 'text-txtsecondary' : ''}">
<span>{model.id}</span>
</td>
</tr>
{/each}
</tbody>
</table>
</div>
{/each}
{/if}
</div>
</div>
@@ -0,0 +1,152 @@
<script lang="ts">
import type { Snippet } from "svelte";
import { onMount } from "svelte";
interface Props {
direction: "horizontal" | "vertical";
storageKey: string;
leftPanel: Snippet;
rightPanel: Snippet;
defaultSize?: number;
minSize?: number;
}
let { direction, storageKey, leftPanel, rightPanel, defaultSize = 50, minSize = 5 }: Props = $props();
let containerRef: HTMLDivElement;
let isDragging = $state(false);
// svelte-ignore state_referenced_locally
let leftSize = $state(defaultSize);
// Load saved size from localStorage
onMount(() => {
const saved = localStorage.getItem(`panel-size-${storageKey}`);
if (saved) {
const parsed = parseFloat(saved);
if (!isNaN(parsed) && parsed >= minSize && parsed <= 100 - minSize) {
leftSize = parsed;
}
}
});
function saveSize(): void {
localStorage.setItem(`panel-size-${storageKey}`, String(leftSize));
}
function handleMouseDown(e: MouseEvent): void {
e.preventDefault();
isDragging = true;
document.addEventListener("mousemove", handleMouseMove);
document.addEventListener("mouseup", handleMouseUp);
}
function handleTouchStart(_e: TouchEvent): void {
isDragging = true;
document.addEventListener("touchmove", handleTouchMove);
document.addEventListener("touchend", handleTouchEnd);
}
function handleMouseMove(e: MouseEvent): void {
if (!isDragging || !containerRef) return;
updateSize(e.clientX, e.clientY);
}
function handleTouchMove(e: TouchEvent): void {
if (!isDragging || !containerRef || e.touches.length === 0) return;
updateSize(e.touches[0].clientX, e.touches[0].clientY);
}
function updateSize(clientX: number, clientY: number): void {
const rect = containerRef.getBoundingClientRect();
let newSize: number;
if (direction === "horizontal") {
newSize = ((clientX - rect.left) / rect.width) * 100;
} else {
newSize = ((clientY - rect.top) / rect.height) * 100;
}
// Clamp size
newSize = Math.max(minSize, Math.min(100 - minSize, newSize));
leftSize = newSize;
}
function handleMouseUp(): void {
isDragging = false;
saveSize();
document.removeEventListener("mousemove", handleMouseMove);
document.removeEventListener("mouseup", handleMouseUp);
}
function handleTouchEnd(): void {
isDragging = false;
saveSize();
document.removeEventListener("touchmove", handleTouchMove);
document.removeEventListener("touchend", handleTouchEnd);
}
function handleKeyDown(e: KeyboardEvent): void {
const step = 2; // 2% increment for keyboard navigation
const key = e.key;
if (direction === "horizontal" && (key === "ArrowLeft" || key === "ArrowRight")) {
e.preventDefault();
const delta = key === "ArrowLeft" ? -step : step;
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
leftSize = newSize;
saveSize();
} else if (direction === "vertical" && (key === "ArrowUp" || key === "ArrowDown")) {
e.preventDefault();
const delta = key === "ArrowUp" ? -step : step;
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
leftSize = newSize;
saveSize();
}
}
let containerClass = $derived(direction === "horizontal" ? "flex-row" : "flex-col");
let handleClass = $derived(
direction === "horizontal"
? "w-2 h-full cursor-col-resize"
: "w-full h-2 cursor-row-resize"
);
let leftStyle = $derived(
direction === "horizontal"
? `width: ${leftSize}%; min-width: ${minSize}%`
: `height: ${leftSize}%; min-height: ${minSize}%`
);
let rightStyle = $derived(
direction === "horizontal"
? `width: ${100 - leftSize}%; min-width: ${minSize}%`
: `height: ${100 - leftSize}%; min-height: ${minSize}%`
);
</script>
<div bind:this={containerRef} class="flex {containerClass} h-full w-full gap-2">
<div style={leftStyle} class="overflow-hidden">
{@render leftPanel()}
</div>
<!-- svelte-ignore a11y_no_noninteractive_tabindex -->
<!-- svelte-ignore a11y_no_noninteractive_element_interactions -->
<div
role="separator"
tabindex="0"
class="{handleClass} bg-primary hover:bg-success transition-colors rounded flex-shrink-0"
onmousedown={handleMouseDown}
ontouchstart={handleTouchStart}
onkeydown={handleKeyDown}
aria-label="Resize panels"
aria-orientation={direction}
aria-valuenow={Math.round(leftSize)}
aria-valuemin={minSize}
aria-valuemax={100 - minSize}
></div>
<div style={rightStyle} class="overflow-hidden">
{@render rightPanel()}
</div>
</div>
+167
View File
@@ -0,0 +1,167 @@
<script lang="ts">
import { inFlightRequests, metrics } from "../stores/api";
import TokenHistogram from "./TokenHistogram.svelte";
interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
let stats = $derived.by(() => {
const totalRequests = $metrics.length;
if (totalRequests === 0) {
return {
totalRequests: 0,
totalInputTokens: 0,
totalOutputTokens: 0,
inFlightRequests: $inFlightRequests,
tokenStats: { p99: "0", p95: "0", p50: "0" },
histogramData: null,
};
}
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
// Calculate token statistics using output_tokens and duration_ms
const validMetrics = $metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
if (validMetrics.length === 0) {
return {
totalRequests,
totalInputTokens,
totalOutputTokens,
inFlightRequests: $inFlightRequests,
tokenStats: { p99: "0", p95: "0", p50: "0" },
histogramData: null,
};
}
// Calculate tokens/second for each valid metric
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
// Sort for percentile calculation
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
// Create histogram data
const min = Math.min(...tokensPerSecond);
const max = Math.max(...tokensPerSecond);
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5)));
const binSize = (max - min) / binCount;
const bins = Array(binCount).fill(0);
tokensPerSecond.forEach((value) => {
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
bins[binIndex]++;
});
const histogramData: HistogramData = {
bins,
min,
max,
binSize,
p99,
p95,
p50,
};
return {
totalRequests,
totalInputTokens,
totalOutputTokens,
inFlightRequests: $inFlightRequests,
tokenStats: {
p99: p99.toFixed(2),
p95: p95.toFixed(2),
p50: p50.toFixed(2),
},
histogramData,
};
});
const nf = new Intl.NumberFormat();
</script>
<div class="card">
<div class="rounded-lg overflow-hidden border border-card-border-inner">
<table class="min-w-full divide-y divide-card-border-inner">
<thead class="bg-secondary">
<tr>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">Requests</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Processed
</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Generated
</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Token Stats (tokens/sec)
</th>
</tr>
</thead>
<tbody class="bg-surface divide-y divide-card-border-inner">
<tr class="hover:bg-secondary">
<td class="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">
<div class="flex flex-col gap-1">
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Completed: {nf.format(stats.totalRequests)}</span>
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Waiting: {nf.format(stats.inFlightRequests)}</span>
</div>
</td>
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{nf.format(stats.totalInputTokens)}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{nf.format(stats.totalOutputTokens)}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td class="px-4 py-4 border-l border-gray-200 dark:border-white/10">
<div class="space-y-3">
<div class="grid grid-cols-3 gap-2 items-center">
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P50</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p50}
</div>
</div>
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P95</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p95}
</div>
</div>
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P99</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p99}
</div>
</div>
</div>
{#if stats.histogramData}
<TokenHistogram data={stats.histogramData} />
{/if}
</div>
</td>
</tr>
</tbody>
</table>
</div>
</div>
@@ -0,0 +1,129 @@
<script lang="ts">
interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
interface Props {
data: HistogramData;
}
let { data }: Props = $props();
const height = 120;
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
const viewBoxWidth = 600;
const chartWidth = viewBoxWidth - padding.left - padding.right;
const chartHeight = height - padding.top - padding.bottom;
let maxCount = $derived(Math.max(...data.bins));
let barWidth = $derived(chartWidth / data.bins.length);
let range = $derived(data.max - data.min);
function getXPosition(value: number): number {
return padding.left + ((value - data.min) / range) * chartWidth;
}
</script>
<div class="mt-2 w-full">
<svg viewBox="0 0 {viewBoxWidth} {height}" class="w-full h-auto" preserveAspectRatio="xMidYMid meet">
<!-- Y-axis -->
<line
x1={padding.left}
y1={padding.top}
x2={padding.left}
y2={height - padding.bottom}
stroke="currentColor"
stroke-width="1"
opacity="0.3"
/>
<!-- X-axis -->
<line
x1={padding.left}
y1={height - padding.bottom}
x2={viewBoxWidth - padding.right}
y2={height - padding.bottom}
stroke="currentColor"
stroke-width="1"
opacity="0.3"
/>
<!-- Histogram bars -->
{#each data.bins as count, i}
{@const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0}
{@const x = padding.left + i * barWidth}
{@const y = height - padding.bottom - barHeight}
{@const binStart = data.min + i * data.binSize}
{@const binEnd = binStart + data.binSize}
<g>
<rect
{x}
{y}
width={Math.max(barWidth - 1, 1)}
height={barHeight}
fill="currentColor"
opacity="0.6"
class="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
/>
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
</g>
{/each}
<!-- Percentile lines -->
<line
x1={getXPosition(data.p50)}
y1={padding.top}
x2={getXPosition(data.p50)}
y2={height - padding.bottom}
stroke="currentColor"
stroke-width="2"
stroke-dasharray="4 2"
opacity="0.7"
class="text-gray-600 dark:text-gray-400"
/>
<line
x1={getXPosition(data.p95)}
y1={padding.top}
x2={getXPosition(data.p95)}
y2={height - padding.bottom}
stroke="currentColor"
stroke-width="2"
stroke-dasharray="4 2"
opacity="0.7"
class="text-orange-500 dark:text-orange-400"
/>
<line
x1={getXPosition(data.p99)}
y1={padding.top}
x2={getXPosition(data.p99)}
y2={height - padding.bottom}
stroke="currentColor"
stroke-width="2"
stroke-dasharray="4 2"
opacity="0.7"
class="text-green-500 dark:text-green-400"
/>
<!-- X-axis labels -->
<text x={padding.left} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="start">
{data.min.toFixed(1)}
</text>
<text x={viewBoxWidth - padding.right} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="end">
{data.max.toFixed(1)}
</text>
<!-- X-axis label -->
<text x={padding.left + chartWidth / 2} y={height - 2} font-size="10" fill="currentColor" opacity="0.6" text-anchor="middle">
Tokens/Second Distribution
</text>
</svg>
</div>
+20
View File
@@ -0,0 +1,20 @@
<script lang="ts">
interface Props {
content: string;
}
let { content }: Props = $props();
</script>
<div class="relative group inline-block">
<span class="cursor-help">&#9432;</span>
<div
class="absolute top-full left-1/2 transform -translate-x-1/2 mt-2
px-3 py-2 bg-gray-900 text-white text-sm rounded-md
opacity-0 group-hover:opacity-100 transition-opacity
duration-200 pointer-events-none whitespace-nowrap z-50 normal-case"
>
{content}
<div class="absolute bottom-full left-1/2 transform -translate-x-1/2 border-4 border-transparent border-b-gray-900"></div>
</div>
</div>
@@ -0,0 +1,256 @@
<script lang="ts">
import { models } from "../../stores/api";
import { persistentStore } from "../../stores/persistent";
import { transcribeAudio } from "../../lib/audioApi";
import { playgroundStores } from "../../stores/playgroundActivity";
import ModelSelector from "./ModelSelector.svelte";
const selectedModelStore = persistentStore<string>("playground-audio-model", "");
let selectedFile = $state<File | null>(null);
let isTranscribing = $state(false);
let transcriptionResult = $state<string | null>(null);
let error = $state<string | null>(null);
let abortController = $state<AbortController | null>(null);
let isDragging = $state(false);
let fileInput = $state<HTMLInputElement | null>(null);
let copied = $state(false);
const ACCEPTED_FORMATS = ['.mp3', '.wav', '.ogg'];
const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB
let hasModels = $derived($models.some((m) => !m.unlisted));
let canTranscribe = $derived(selectedFile !== null && $selectedModelStore !== "" && !isTranscribing);
$effect(() => {
playgroundStores.audioTranscribing.set(isTranscribing);
});
function validateFile(file: File): { valid: boolean; error?: string } {
const ext = '.' + file.name.split('.').pop()?.toLowerCase();
if (!ACCEPTED_FORMATS.includes(ext)) {
return { valid: false, error: 'Invalid file type. Accepted: MP3, WAV, OGG' };
}
if (file.size > MAX_FILE_SIZE) {
return { valid: false, error: 'File too large. Maximum: 25MB' };
}
return { valid: true };
}
function handleFileSelect(event: Event) {
const target = event.target as HTMLInputElement;
const file = target.files?.[0];
if (file) {
const validation = validateFile(file);
if (validation.valid) {
selectedFile = file;
error = null;
transcriptionResult = null;
} else {
error = validation.error || "Invalid file";
selectedFile = null;
}
}
}
function handleDragOver(event: DragEvent) {
event.preventDefault();
isDragging = true;
}
function handleDragLeave() {
isDragging = false;
}
function handleDrop(event: DragEvent) {
event.preventDefault();
isDragging = false;
const file = event.dataTransfer?.files[0];
if (file) {
const validation = validateFile(file);
if (validation.valid) {
selectedFile = file;
error = null;
transcriptionResult = null;
} else {
error = validation.error || "Invalid file";
selectedFile = null;
}
}
}
async function transcribe() {
if (!selectedFile || !$selectedModelStore || isTranscribing) return;
isTranscribing = true;
error = null;
transcriptionResult = null;
abortController = new AbortController();
try {
const response = await transcribeAudio(
$selectedModelStore,
selectedFile,
abortController.signal
);
transcriptionResult = response.text;
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
// User cancelled
} else {
error = err instanceof Error ? err.message : "An error occurred";
}
} finally {
isTranscribing = false;
abortController = null;
}
}
function cancelTranscription() {
abortController?.abort();
}
function clearAll() {
selectedFile = null;
transcriptionResult = null;
error = null;
if (fileInput) {
fileInput.value = '';
}
}
function copyToClipboard() {
if (transcriptionResult) {
navigator.clipboard.writeText(transcriptionResult);
copied = true;
setTimeout(() => {
copied = false;
}, 2000);
}
}
function formatFileSize(bytes: number): string {
if (bytes < 1024) return bytes + ' B';
if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KB';
return (bytes / (1024 * 1024)).toFixed(1) + ' MB';
}
</script>
<div class="flex flex-col h-full">
<!-- Model selector -->
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} />
</div>
<!-- Empty state for no models configured -->
{#if !hasModels}
<div class="flex-1 flex items-center justify-center text-txtsecondary">
<p>No models configured. Add models to your configuration to transcribe audio.</p>
</div>
{:else}
<!-- File upload / Result display area -->
<div class="flex-1 overflow-auto mb-4 flex items-center justify-center bg-surface border border-gray-200 dark:border-white/10 rounded">
{#if isTranscribing}
<div class="text-center text-txtsecondary">
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
<p>Transcribing audio...</p>
</div>
{:else if error}
<div class="text-center text-red-500 p-4">
<p class="font-medium">Error</p>
<p class="text-sm mt-1">{error}</p>
</div>
{:else if transcriptionResult}
<div class="w-full h-full flex flex-col p-4">
<div class="flex justify-between items-center mb-2">
<h3 class="font-medium">Transcription Result</h3>
<button
class="btn btn-sm"
onclick={copyToClipboard}
title={copied ? 'Copied!' : 'Copy to clipboard'}
>
{#if copied}
<svg class="w-5 h-5 text-green-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"></path>
</svg>
{:else}
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"></path>
</svg>
{/if}
</button>
</div>
<div class="flex-1 overflow-auto p-3 rounded border border-gray-200 dark:border-white/10 bg-background whitespace-pre-wrap">
{transcriptionResult}
</div>
</div>
{:else if selectedFile}
<div class="text-center text-txtsecondary p-4">
<p class="font-medium mb-2">File Selected</p>
<p class="text-sm">{selectedFile.name}</p>
<p class="text-xs mt-1">{formatFileSize(selectedFile.size)}</p>
</div>
{:else}
<div
role="region"
aria-label="Audio file drop zone"
class="w-full h-full flex items-center justify-center text-center text-txtsecondary p-8 {isDragging ? 'bg-primary/10' : ''}"
ondragover={handleDragOver}
ondragleave={handleDragLeave}
ondrop={handleDrop}
>
<div>
<p class="mb-2">Drag and drop an audio file here</p>
<p class="text-sm">or use the Browse button below</p>
<p class="text-xs mt-4">Accepted formats: MP3, WAV, OGG (max 25MB)</p>
</div>
</div>
{/if}
</div>
<!-- File input and transcribe button -->
<div class="shrink-0 flex gap-2">
<input
type="file"
accept=".mp3,.wav,.ogg"
class="hidden"
onchange={handleFileSelect}
bind:this={fileInput}
/>
<button
class="btn"
onclick={() => fileInput?.click()}
disabled={isTranscribing}
>
Browse Files
</button>
<div class="flex-1"></div>
{#if isTranscribing}
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancelTranscription}>
Cancel
</button>
{:else}
<button
class="btn bg-primary text-btn-primary-text hover:opacity-90"
onclick={transcribe}
disabled={!canTranscribe}
>
Transcribe
</button>
<button
class="btn"
onclick={clearAll}
disabled={!selectedFile && !transcriptionResult && !error}
>
Clear
</button>
{/if}
</div>
{/if}
</div>
@@ -0,0 +1,466 @@
<script lang="ts">
import { models } from "../../stores/api";
import { persistentStore } from "../../stores/persistent";
import { streamChatCompletion } from "../../lib/chatApi";
import { playgroundStores } from "../../stores/playgroundActivity";
import type { ChatMessage, ContentPart } from "../../lib/types";
import ChatMessageComponent from "./ChatMessage.svelte";
import ModelSelector from "./ModelSelector.svelte";
import ExpandableTextarea from "./ExpandableTextarea.svelte";
const selectedModelStore = persistentStore<string>("playground-selected-model", "");
const systemPromptStore = persistentStore<string>("playground-system-prompt", "");
const temperatureStore = persistentStore<number>("playground-temperature", 0.7);
function loadMessages(): ChatMessage[] {
try {
const saved = localStorage.getItem("playground-messages");
return saved ? JSON.parse(saved) : [];
} catch {
return [];
}
}
let messages = $state<ChatMessage[]>(loadMessages());
let userInput = $state("");
let isStreaming = $state(false);
let isReasoning = $state(false);
let reasoningStartTime = $state<number>(0);
let abortController = $state<AbortController | null>(null);
let messagesContainer: HTMLDivElement | undefined = $state();
let showSettings = $state(false);
let attachedImages = $state<string[]>([]);
let fileInput = $state<HTMLInputElement | null>(null);
let imageError = $state<string | null>(null);
let hasModels = $derived($models.some((m) => !m.unlisted));
let userScrolledUp = $state(false);
$effect(() => {
playgroundStores.chatStreaming.set(isStreaming);
});
function handleMessagesScroll() {
if (!messagesContainer) return;
const { scrollTop, scrollHeight, clientHeight } = messagesContainer;
// Consider "at bottom" if within 40px of the bottom
userScrolledUp = scrollHeight - scrollTop - clientHeight > 40;
}
// Auto-scroll when messages change — skip if user scrolled up
$effect(() => {
if (messages.length > 0 && messagesContainer && !userScrolledUp) {
messagesContainer.scrollTo({
top: messagesContainer.scrollHeight,
behavior: isStreaming ? "instant" : "smooth",
});
}
});
// Persist messages to localStorage (throttled to once per 2s)
let lastSaveTime = 0;
$effect(() => {
const json = JSON.stringify(messages);
const elapsed = Date.now() - lastSaveTime;
const save = () => {
try { localStorage.setItem("playground-messages", json); } catch {}
lastSaveTime = Date.now();
};
if (elapsed >= 2000) {
save();
return;
}
const timer = setTimeout(save, 2000 - elapsed);
return () => clearTimeout(timer);
});
async function sendMessage() {
const trimmedInput = userInput.trim();
if ((!trimmedInput && attachedImages.length === 0) || !$selectedModelStore || isStreaming) return;
userScrolledUp = false;
// Build message content (multimodal if images attached)
let content: string | ContentPart[];
if (attachedImages.length > 0) {
const parts: ContentPart[] = [];
if (trimmedInput) {
parts.push({ type: "text", text: trimmedInput });
}
for (const url of attachedImages) {
parts.push({ type: "image_url", image_url: { url } });
}
content = parts;
} else {
content = trimmedInput;
}
// Add user message
messages = [...messages, { role: "user", content }];
userInput = "";
attachedImages = [];
imageError = null;
// Generate response from the new user message
await regenerateFromIndex(messages.length - 1);
}
function cancelStreaming() {
abortController?.abort();
}
function newChat() {
if (isStreaming) {
cancelStreaming();
}
messages = [];
isReasoning = false;
reasoningStartTime = 0;
}
async function regenerateFromIndex(idx: number) {
// Remove all messages after the edited user message
messages = messages.slice(0, idx + 1);
// Add empty assistant message for the new response
messages = [...messages, { role: "assistant", content: "" }];
isStreaming = true;
isReasoning = false;
reasoningStartTime = 0;
abortController = new AbortController();
try {
// Build messages array with optional system prompt
const apiMessages: ChatMessage[] = [];
if ($systemPromptStore.trim()) {
apiMessages.push({ role: "system", content: $systemPromptStore.trim() });
}
apiMessages.push(...messages.slice(0, -1)); // Add all messages except the empty assistant one
const stream = streamChatCompletion(
$selectedModelStore,
apiMessages,
abortController.signal,
{ temperature: $temperatureStore }
);
for await (const chunk of stream) {
if (chunk.done) break;
// Handle reasoning content
if (chunk.reasoning_content) {
// Start timing on first reasoning content
if (!isReasoning) {
isReasoning = true;
reasoningStartTime = Date.now();
}
// Update the last message with reasoning content
messages = messages.map((msg, i) =>
i === messages.length - 1
? { ...msg, reasoning_content: (msg.reasoning_content || "") + chunk.reasoning_content }
: msg
);
}
// Handle regular content - end reasoning phase when we get content
if (chunk.content) {
if (isReasoning) {
// Calculate reasoning time
const reasoningTimeMs = Date.now() - reasoningStartTime;
isReasoning = false;
// Update message with reasoning time
messages = messages.map((msg, i) =>
i === messages.length - 1
? { ...msg, reasoningTimeMs }
: msg
);
}
// Update the last message (assistant) with new content
messages = messages.map((msg, i) =>
i === messages.length - 1
? { ...msg, content: msg.content + chunk.content }
: msg
);
}
}
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
// User cancelled, keep partial response
// If we were still reasoning, record the time
if (isReasoning && reasoningStartTime > 0) {
const reasoningTimeMs = Date.now() - reasoningStartTime;
messages = messages.map((msg, i) =>
i === messages.length - 1
? { ...msg, reasoningTimeMs }
: msg
);
}
} else {
// Show error in the assistant message
const errorMessage = error instanceof Error ? error.message : "An error occurred";
messages = messages.map((msg, i) =>
i === messages.length - 1
? { ...msg, content: msg.content + `\n\n**Error:** ${errorMessage}` }
: msg
);
}
} finally {
isStreaming = false;
isReasoning = false;
abortController = null;
}
}
async function editMessage(idx: number, newContent: string) {
if (isStreaming || !$selectedModelStore) return;
// Update the user message at the specified index
messages = messages.map((msg, i) =>
i === idx ? { ...msg, content: newContent } : msg
);
// Trigger a new chat request with the updated messages
await regenerateFromIndex(idx);
}
function handleKeyDown(event: KeyboardEvent) {
if (event.key === "Enter" && !event.shiftKey) {
event.preventDefault();
sendMessage();
}
}
const ACCEPTED_IMAGE_FORMATS = ["image/jpeg", "image/png", "image/gif", "image/webp"];
const MAX_IMAGE_SIZE = 20 * 1024 * 1024; // 20MB
const MAX_IMAGES_PER_MESSAGE = 5;
function validateImageFile(file: File): string | null {
if (!ACCEPTED_IMAGE_FORMATS.includes(file.type)) {
return `Invalid file type: ${file.type}. Accepted formats: JPG, PNG, GIF, WEBP`;
}
if (file.size > MAX_IMAGE_SIZE) {
return `File too large: ${(file.size / 1024 / 1024).toFixed(1)}MB. Maximum size: 20MB`;
}
return null;
}
function fileToDataUrl(file: File): Promise<string> {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = () => resolve(reader.result as string);
reader.onerror = () => reject(new Error("Failed to read file"));
reader.readAsDataURL(file);
});
}
async function processImageFiles(files: File[]): Promise<void> {
imageError = null;
if (attachedImages.length + files.length > MAX_IMAGES_PER_MESSAGE) {
imageError = `Maximum ${MAX_IMAGES_PER_MESSAGE} images per message`;
return;
}
for (const file of files) {
const error = validateImageFile(file);
if (error) {
imageError = error;
return;
}
}
try {
const dataUrls = await Promise.all(files.map(fileToDataUrl));
attachedImages = [...attachedImages, ...dataUrls];
} catch (error) {
imageError = error instanceof Error ? error.message : "Failed to process images";
}
}
function handleImageSelect(event: Event) {
const input = event.target as HTMLInputElement;
if (input.files && input.files.length > 0) {
processImageFiles(Array.from(input.files));
}
// Reset the input so the same file can be selected again
input.value = "";
}
function removeImage(idx: number) {
attachedImages = attachedImages.filter((_, i) => i !== idx);
imageError = null;
}
</script>
<div class="flex flex-col h-full">
<!-- Model selector and controls -->
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a model..." disabled={isStreaming} />
<div class="flex gap-2">
<button
class="btn"
onclick={() => (showSettings = !showSettings)}
title="Settings"
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5">
<path fill-rule="evenodd" d="M8.34 1.804A1 1 0 0 1 9.32 1h1.36a1 1 0 0 1 .98.804l.295 1.473c.497.144.971.342 1.416.587l1.25-.834a1 1 0 0 1 1.262.125l.962.962a1 1 0 0 1 .125 1.262l-.834 1.25c.245.445.443.919.587 1.416l1.473.295a1 1 0 0 1 .804.98v1.36a1 1 0 0 1-.804.98l-1.473.295a6.95 6.95 0 0 1-.587 1.416l.834 1.25a1 1 0 0 1-.125 1.262l-.962.962a1 1 0 0 1-1.262.125l-1.25-.834a6.953 6.953 0 0 1-1.416.587l-.295 1.473a1 1 0 0 1-.98.804H9.32a1 1 0 0 1-.98-.804l-.295-1.473a6.957 6.957 0 0 1-1.416-.587l-1.25.834a1 1 0 0 1-1.262-.125l-.962-.962a1 1 0 0 1-.125-1.262l.834-1.25a6.957 6.957 0 0 1-.587-1.416l-1.473-.295A1 1 0 0 1 1 10.68V9.32a1 1 0 0 1 .804-.98l1.473-.295c.144-.497.342-.971.587-1.416l-.834-1.25a1 1 0 0 1 .125-1.262l.962-.962A1 1 0 0 1 5.38 3.03l1.25.834a6.957 6.957 0 0 1 1.416-.587l.294-1.473ZM13 10a3 3 0 1 1-6 0 3 3 0 0 1 6 0Z" clip-rule="evenodd" />
</svg>
</button>
<button class="btn" onclick={newChat} disabled={messages.length === 0 && !isStreaming}>
New Chat
</button>
</div>
</div>
<!-- Settings panel -->
{#if showSettings}
<div class="shrink-0 mb-4 p-4 bg-surface border border-gray-200 dark:border-white/10 rounded">
<div class="mb-4">
<label class="block text-sm font-medium mb-1" for="system-prompt">System Prompt</label>
<textarea
id="system-prompt"
class="w-full px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-card focus:outline-none focus:ring-2 focus:ring-primary resize-none"
placeholder="You are a helpful assistant..."
rows="3"
bind:value={$systemPromptStore}
disabled={isStreaming}
></textarea>
</div>
<div>
<label class="block text-sm font-medium mb-1" for="temperature">
Temperature: {$temperatureStore.toFixed(2)}
</label>
<input
id="temperature"
type="range"
min="0"
max="2"
step="0.05"
class="w-full"
bind:value={$temperatureStore}
disabled={isStreaming}
/>
<div class="flex justify-between text-xs text-txtsecondary mt-1">
<span>Precise (0)</span>
<span>Creative (2)</span>
</div>
</div>
</div>
{/if}
<!-- Empty state for no models configured -->
{#if !hasModels}
<div class="flex-1 flex items-center justify-center text-txtsecondary">
<p>No models configured. Add models to your configuration to start chatting.</p>
</div>
{:else}
<!-- Messages area -->
<div
class="flex-1 overflow-y-auto mb-4 px-2"
bind:this={messagesContainer}
onscroll={handleMessagesScroll}
>
{#if messages.length === 0}
<div class="h-full flex items-center justify-center text-txtsecondary">
<p>Start a conversation by typing a message below.</p>
</div>
{:else}
{#each messages as message, idx (idx)}
<ChatMessageComponent
role={message.role}
content={message.content}
reasoning_content={message.reasoning_content}
reasoningTimeMs={message.reasoningTimeMs}
isStreaming={isStreaming && idx === messages.length - 1 && message.role === "assistant"}
isReasoning={isReasoning && idx === messages.length - 1 && message.role === "assistant"}
onEdit={message.role === "user" ? (newContent) => editMessage(idx, newContent) : undefined}
onRegenerate={message.role === "assistant" && idx > 0 && messages[idx - 1].role === "user"
? () => regenerateFromIndex(idx - 1)
: undefined}
/>
{/each}
{/if}
</div>
<!-- Input area -->
<div class="shrink-0">
<!-- Image preview strip -->
{#if attachedImages.length > 0}
<div class="mb-2 flex flex-wrap gap-2">
{#each attachedImages as imageUrl, idx (idx)}
<div class="relative group">
<img
src={imageUrl}
alt="Attached image {idx + 1}"
class="w-20 h-20 object-cover rounded border border-gray-200 dark:border-white/10"
/>
<button
class="absolute -top-2 -right-2 bg-red-500 text-white rounded-full w-6 h-6 flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity"
onclick={() => removeImage(idx)}
title="Remove image"
>
×
</button>
</div>
{/each}
</div>
{/if}
<!-- Error message -->
{#if imageError}
<div class="mb-2 p-2 bg-red-100 dark:bg-red-900/20 text-red-700 dark:text-red-400 rounded text-sm">
{imageError}
</div>
{/if}
<div class="flex gap-2">
<!-- Hidden file input -->
<input
type="file"
accept=".jpg,.jpeg,.png,.gif,.webp"
multiple
class="hidden"
bind:this={fileInput}
onchange={handleImageSelect}
/>
<ExpandableTextarea
bind:value={userInput}
placeholder="Type a message..."
rows={3}
onkeydown={handleKeyDown}
disabled={isStreaming || !$selectedModelStore}
/>
<div class="flex flex-col gap-2">
{#if isStreaming}
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancelStreaming}>
Cancel
</button>
{:else}
<button
class="btn"
onclick={() => fileInput?.click()}
disabled={isStreaming || !$selectedModelStore}
title="Attach image"
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5">
<path fill-rule="evenodd" d="M1 5.25A2.25 2.25 0 0 1 3.25 3h13.5A2.25 2.25 0 0 1 19 5.25v9.5A2.25 2.25 0 0 1 16.75 17H3.25A2.25 2.25 0 0 1 1 14.75v-9.5Zm1.5 5.81v3.69c0 .414.336.75.75.75h13.5a.75.75 0 0 0 .75-.75v-2.69l-2.22-2.219a.75.75 0 0 0-1.06 0l-1.91 1.909.47.47a.75.75 0 1 1-1.06 1.06L6.53 8.091a.75.75 0 0 0-1.06 0l-2.97 2.97ZM12 7a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z" clip-rule="evenodd" />
</svg>
</button>
<button
class="btn bg-primary text-btn-primary-text hover:opacity-90"
onclick={sendMessage}
disabled={(!userInput.trim() && attachedImages.length === 0) || !$selectedModelStore}
>
Send
</button>
{/if}
</div>
</div>
</div>
{/if}
</div>
@@ -0,0 +1,467 @@
<script lang="ts">
import { renderMarkdown, escapeHtml, renderStreamingMarkdown, createStreamingCache } from "../../lib/markdown";
import type { RenderedBlock } from "../../lib/markdown";
import { Copy, Check, Pencil, X, Save, RefreshCw, ChevronDown, ChevronRight, Brain, Code } from "lucide-svelte";
import { getTextContent, getImageUrls } from "../../lib/types";
import type { ContentPart } from "../../lib/types";
interface Props {
role: "user" | "assistant" | "system";
content: string | ContentPart[];
reasoning_content?: string;
reasoningTimeMs?: number;
isStreaming?: boolean;
isReasoning?: boolean;
onEdit?: (newContent: string) => void;
onRegenerate?: () => void;
}
let { role, content, reasoning_content = "", reasoningTimeMs = 0, isStreaming = false, isReasoning = false, onEdit, onRegenerate }: Props = $props();
let textContent = $derived(getTextContent(content));
let imageUrls = $derived(getImageUrls(content));
let hasImages = $derived(imageUrls.length > 0);
let canEdit = $derived(onEdit !== undefined && !hasImages);
let streamingCache = createStreamingCache();
let renderedParts = $derived.by(() => {
if (role !== "assistant") {
return { blocks: [{ id: -1, html: escapeHtml(textContent).replace(/\n/g, '<br>') }] as RenderedBlock[], pendingHtml: "" };
}
if (!isStreaming) {
streamingCache = createStreamingCache();
return { blocks: [{ id: -1, html: renderMarkdown(textContent) }] as RenderedBlock[], pendingHtml: "" };
}
return renderStreamingMarkdown(textContent, streamingCache);
});
let copied = $state(false);
let showRaw = $state(false);
let isEditing = $state(false);
let editContent = $state("");
let showReasoning = $state(false);
let modalImageUrl = $state<string | null>(null);
function formatDuration(ms: number): string {
if (ms < 1000) {
return `${ms.toFixed(0)}ms`;
}
return `${(ms / 1000).toFixed(1)}s`;
}
async function copyToClipboard() {
try {
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(textContent);
} else {
// Fallback for non-secure contexts (HTTP)
const textarea = document.createElement("textarea");
textarea.value = textContent;
textarea.style.position = "fixed";
textarea.style.left = "-9999px";
document.body.appendChild(textarea);
textarea.select();
document.execCommand("copy");
document.body.removeChild(textarea);
}
copied = true;
setTimeout(() => (copied = false), 2000);
} catch (err) {
console.error("Failed to copy:", err);
}
}
function startEdit() {
editContent = textContent;
isEditing = true;
}
function cancelEdit() {
isEditing = false;
editContent = "";
}
function saveEdit() {
if (onEdit && editContent.trim() !== textContent) {
onEdit(editContent.trim());
}
isEditing = false;
editContent = "";
}
function openModal(imageUrl: string) {
modalImageUrl = imageUrl;
document.body.style.overflow = "hidden";
}
function closeModal(event?: MouseEvent) {
// Only close if clicking the background, not the image
if (event && event.target !== event.currentTarget) {
return;
}
modalImageUrl = null;
document.body.style.overflow = "";
}
function handleModalKeyDown(event: KeyboardEvent) {
if (event.key === "Escape") {
closeModal();
}
}
function handleKeyDown(event: KeyboardEvent) {
if (event.key === "Enter" && !event.shiftKey) {
event.preventDefault();
saveEdit();
} else if (event.key === "Escape") {
cancelEdit();
}
}
const COPY_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>`;
const CHECK_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M20 6 9 17l-5-5"/></svg>`;
function codeBlockCopy(node: HTMLElement) {
function attachButtons() {
node.querySelectorAll<HTMLPreElement>('pre:not([data-copy-btn])').forEach(pre => {
pre.setAttribute('data-copy-btn', 'true');
const btn = document.createElement('button');
btn.className = 'code-copy-btn';
btn.title = 'Copy code';
btn.innerHTML = COPY_SVG;
btn.addEventListener('click', async () => {
const text = pre.querySelector('code')?.textContent ?? pre.textContent ?? '';
try {
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(text);
} else {
const ta = document.createElement('textarea');
ta.value = text;
ta.style.cssText = 'position:fixed;left:-9999px';
document.body.appendChild(ta);
ta.select();
document.execCommand('copy');
document.body.removeChild(ta);
}
btn.innerHTML = CHECK_SVG;
btn.classList.add('copied');
setTimeout(() => { btn.innerHTML = COPY_SVG; btn.classList.remove('copied'); }, 2000);
} catch (e) {
console.error('copy failed', e);
}
});
pre.appendChild(btn);
});
}
attachButtons();
const mo = new MutationObserver(attachButtons);
mo.observe(node, { childList: true, subtree: true });
return { destroy: () => mo.disconnect() };
}
</script>
<div class="flex {role === 'user' ? 'justify-end' : 'justify-start'} mb-4">
<div
class="relative group rounded-lg px-4 py-2 {role === 'user'
? 'max-w-[85%] bg-primary text-btn-primary-text'
: 'w-full sm:w-4/5 bg-surface border border-gray-200 dark:border-white/10'}"
>
{#if role === "assistant"}
{#if reasoning_content || isReasoning}
<div class="mb-3 border border-gray-200 dark:border-white/10 rounded overflow-hidden">
<button
class="w-full flex items-center gap-2 px-3 py-2 bg-gray-50 dark:bg-white/5 hover:bg-gray-100 dark:hover:bg-white/10 transition-colors text-sm"
onclick={() => showReasoning = !showReasoning}
>
{#if showReasoning}
<ChevronDown class="w-4 h-4" />
{:else}
<ChevronRight class="w-4 h-4" />
{/if}
<Brain class="w-4 h-4" />
<span class="font-medium">Reasoning</span>
<span class="text-txtsecondary ml-2">
({reasoning_content.length} chars{#if !isReasoning && reasoningTimeMs > 0}, {formatDuration(reasoningTimeMs)}{/if})
</span>
{#if isReasoning}
<span class="ml-auto flex items-center gap-1 text-txtsecondary">
<span class="w-1.5 h-1.5 bg-primary rounded-full animate-pulse"></span>
reasoning...
</span>
{/if}
</button>
{#if showReasoning}
<div class="px-3 py-2 bg-gray-50/50 dark:bg-white/[0.02] text-sm text-txtsecondary whitespace-pre-wrap font-mono">
{reasoning_content}{#if isReasoning}<span class="inline-block w-1.5 h-4 bg-current animate-pulse ml-0.5"></span>{/if}
</div>
{/if}
</div>
{/if}
{#if hasImages}
<div class="mb-3 flex flex-wrap gap-2">
{#each imageUrls as imageUrl, idx (idx)}
<button
onclick={() => openModal(imageUrl)}
class="cursor-pointer rounded border border-gray-200 dark:border-white/10 hover:opacity-80 transition-opacity"
>
<img
src={imageUrl}
alt="Image {idx + 1}"
class="max-h-64 rounded"
/>
</button>
{/each}
</div>
{/if}
{#if showRaw}
<div class="whitespace-pre-wrap font-mono text-sm">{textContent}</div>
{:else}
<div class="prose prose-sm dark:prose-invert max-w-none" use:codeBlockCopy>
{#each renderedParts.blocks as block (block.id)}
{@html block.html}
{/each}
{@html renderedParts.pendingHtml}
{#if isStreaming && !isReasoning}
<span class="inline-block w-2 h-4 bg-current animate-pulse ml-0.5"></span>
{/if}
</div>
{/if}
{#if !isStreaming}
<div class="flex gap-1 mt-2 pt-1 border-t border-gray-200 dark:border-white/10">
{#if onRegenerate}
<button
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 text-txtsecondary"
onclick={onRegenerate}
title="Regenerate response"
>
<RefreshCw class="w-4 h-4" />
</button>
{/if}
<button
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 text-txtsecondary"
onclick={copyToClipboard}
title={copied ? "Copied!" : "Copy to clipboard"}
>
{#if copied}
<Check class="w-4 h-4 text-green-500" />
{:else}
<Copy class="w-4 h-4" />
{/if}
</button>
<button
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 {showRaw ? 'text-primary' : 'text-txtsecondary'}"
onclick={() => showRaw = !showRaw}
title={showRaw ? "Show rendered" : "Show raw"}
>
<Code class="w-4 h-4" />
</button>
</div>
{/if}
{:else}
{#if isEditing}
<div class="flex flex-col gap-2 min-w-[300px]">
<textarea
class="w-full px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface text-txtmain focus:outline-none focus:ring-2 focus:ring-primary resize-none"
rows="3"
bind:value={editContent}
onkeydown={handleKeyDown}
></textarea>
<div class="flex justify-end gap-2">
<button
class="p-1.5 rounded hover:bg-white/20"
onclick={cancelEdit}
title="Cancel"
>
<X class="w-4 h-4" />
</button>
<button
class="p-1.5 rounded hover:bg-white/20"
onclick={saveEdit}
title="Save"
>
<Save class="w-4 h-4" />
</button>
</div>
</div>
{:else}
{#if hasImages}
<div class="mb-2 flex flex-wrap gap-2">
{#each imageUrls as imageUrl, idx (idx)}
<button
onclick={() => openModal(imageUrl)}
class="cursor-pointer rounded border border-white/20 hover:opacity-80 transition-opacity"
>
<img
src={imageUrl}
alt="Image {idx + 1}"
class="max-w-[200px] rounded"
/>
</button>
{/each}
</div>
{/if}
<div class="whitespace-pre-wrap pr-8">{textContent}</div>
{#if canEdit}
<button
class="absolute top-2 right-2 p-1.5 rounded-lg opacity-0 group-hover:opacity-100 transition-opacity bg-white/20 hover:bg-white/30 shadow-sm"
onclick={startEdit}
title="Edit message"
>
<Pencil class="w-4 h-4" />
</button>
{/if}
{/if}
{/if}
</div>
</div>
<!-- Full-size image modal -->
{#if modalImageUrl}
<div
class="fixed inset-0 z-50 flex items-center justify-center bg-black/80 p-4"
onclick={(e) => closeModal(e)}
onkeydown={handleModalKeyDown}
role="button"
tabindex="-1"
>
<button
class="absolute top-4 right-4 p-2 rounded-lg bg-white/10 hover:bg-white/20 text-white transition-colors"
onclick={() => closeModal()}
title="Close"
>
<X class="w-6 h-6" />
</button>
<img
src={modalImageUrl}
alt=""
class="max-w-full max-h-full rounded pointer-events-none"
/>
</div>
{/if}
<style>
.prose :global(pre) {
position: relative;
background-color: var(--color-surface);
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
border-radius: 0.375rem;
padding: 0.75rem;
padding-right: 2.5rem;
overflow-x: auto;
margin: 0.5rem 0;
}
.prose :global(.code-copy-btn) {
position: absolute;
top: 0.375rem;
right: 0.375rem;
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
border-radius: 0.25rem;
border: 1px solid var(--color-border);
background: var(--color-surface);
color: var(--color-txtsecondary);
cursor: pointer;
transition: background-color 0.15s;
line-height: 0;
}
.prose :global(.code-copy-btn:hover) {
background: var(--color-secondary);
}
.prose :global(.code-copy-btn.copied) {
color: var(--color-success);
opacity: 1;
}
.prose :global(code) {
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 0.875em;
}
.prose :global(pre code) {
background: none;
padding: 0;
}
.prose :global(code:not(pre code)) {
background-color: var(--color-surface);
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
}
.prose :global(p) {
margin: 0.5rem 0;
}
.prose :global(p:first-child) {
margin-top: 0;
}
.prose :global(p:last-child) {
margin-bottom: 0;
}
.prose :global(ul),
.prose :global(ol) {
margin: 0.5rem 0;
padding-left: 1.5rem;
}
.prose :global(li) {
margin: 0.25rem 0;
}
.prose :global(h1),
.prose :global(h2),
.prose :global(h3),
.prose :global(h4) {
margin: 1rem 0 0.5rem 0;
font-weight: 600;
}
.prose :global(h1:first-child),
.prose :global(h2:first-child),
.prose :global(h3:first-child),
.prose :global(h4:first-child) {
margin-top: 0;
}
.prose :global(blockquote) {
border-left: 3px solid var(--color-primary);
padding-left: 1rem;
margin: 0.5rem 0;
font-style: italic;
}
.prose :global(a) {
color: var(--color-primary);
text-decoration: underline;
}
.prose :global(table) {
width: 100%;
border-collapse: collapse;
margin: 0.5rem 0;
}
.prose :global(th),
.prose :global(td) {
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
padding: 0.5rem;
text-align: left;
}
.prose :global(th) {
background-color: var(--color-surface);
font-weight: 600;
}
/* Highlight.js theme overrides for dark mode */
:global(.dark) .prose :global(.hljs) {
background: transparent;
}
</style>
@@ -0,0 +1,121 @@
<script lang="ts">
import { untrack } from "svelte";
import { Maximize2, X } from "lucide-svelte";
interface Props {
value: string;
placeholder?: string;
rows?: number;
disabled?: boolean;
onkeydown?: (event: KeyboardEvent) => void;
}
let {
value = $bindable(),
placeholder = "",
rows = 3,
disabled = false,
onkeydown,
}: Props = $props();
let isExpanded = $state(false);
let expandedValue = $state("");
let expandedTextarea: HTMLTextAreaElement | undefined = $state();
function openExpanded() {
expandedValue = value;
isExpanded = true;
}
function closeExpanded() {
isExpanded = false;
}
function saveExpanded() {
value = expandedValue;
isExpanded = false;
}
function handleKeyDown(event: KeyboardEvent) {
if (event.key === "Escape") {
closeExpanded();
}
}
// Focus the textarea when expanded view opens
$effect(() => {
if (isExpanded && expandedTextarea) {
expandedTextarea.focus();
const len = untrack(() => expandedValue.length);
expandedTextarea.setSelectionRange(len, len);
}
});
</script>
<div class="flex-1 relative group flex items-stretch min-h-0">
<textarea
class="w-full px-3 py-2 pr-10 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-inset focus:ring-primary resize-none"
{placeholder}
{rows}
bind:value
{onkeydown}
{disabled}
></textarea>
<button
class="absolute top-2 right-2 p-1.5 rounded-lg opacity-60 md:opacity-0 group-hover:opacity-100 transition-opacity bg-surface/90 hover:bg-surface border border-gray-200 dark:border-white/10 shadow-sm"
onclick={openExpanded}
title="Expand to edit"
type="button"
{disabled}
>
<Maximize2 class="w-4 h-4" />
</button>
</div>
{#if isExpanded}
<div class="fixed inset-0 z-50 flex items-center justify-center bg-black/50 p-4">
<div class="w-full max-w-4xl h-[80vh] flex flex-col bg-surface rounded-lg shadow-xl border border-gray-200 dark:border-white/10">
<!-- Header -->
<div class="flex justify-between items-center p-4 border-b border-gray-200 dark:border-white/10">
<h3 class="font-medium">Edit Text</h3>
<button
class="p-1.5 rounded-lg hover:bg-gray-100 dark:hover:bg-white/10"
onclick={closeExpanded}
title="Close"
type="button"
>
<X class="w-5 h-5" />
</button>
</div>
<!-- Textarea -->
<div class="flex-1 p-4">
<textarea
bind:this={expandedTextarea}
class="w-full h-full px-4 py-3 rounded border border-gray-200 dark:border-white/10 bg-card focus:outline-none focus:ring-2 focus:ring-primary resize-none"
placeholder={placeholder}
bind:value={expandedValue}
onkeydown={handleKeyDown}
></textarea>
</div>
<!-- Footer -->
<div class="flex justify-end gap-2 p-4 border-t border-gray-200 dark:border-white/10">
<button
class="btn"
onclick={closeExpanded}
type="button"
>
Cancel
</button>
<button
class="btn bg-primary text-btn-primary-text hover:opacity-90"
onclick={saveExpanded}
type="button"
>
Done
</button>
</div>
</div>
</div>
{/if}
@@ -0,0 +1,521 @@
<script lang="ts">
import { models } from "../../stores/api";
import { persistentStore } from "../../stores/persistent";
import { generateImage } from "../../lib/imageApi";
import { generateSdImage, fetchSdLoras } from "../../lib/sdApi";
import { playgroundStores } from "../../stores/playgroundActivity";
import ModelSelector from "./ModelSelector.svelte";
import ExpandableTextarea from "./ExpandableTextarea.svelte";
import type { ImageApiMode, SdApiLora, SdApiLoraRef } from "../../lib/types";
const selectedModelStore = persistentStore<string>("playground-image-model", "");
const selectedSizeStore = persistentStore<string>("playground-image-size", "1024x1024");
const apiModeStore = persistentStore<ImageApiMode>("playground-image-api-mode", "openai");
// SDAPI persistent settings
const sdNegativePromptStore = persistentStore<string>("playground-sdapi-negative-prompt", "");
const sdStepsStore = persistentStore<number>("playground-sdapi-steps", 20);
const sdCfgScaleStore = persistentStore<number>("playground-sdapi-cfg-scale", 7);
const sdSeedStore = persistentStore<number>("playground-sdapi-seed", -1);
const sdSamplerStore = persistentStore<string>("playground-sdapi-sampler", "");
const sdSchedulerStore = persistentStore<string>("playground-sdapi-scheduler", "");
const sdBatchSizeStore = persistentStore<number>("playground-sdapi-batch-size", 1);
let prompt = $state("");
let isGenerating = $state(false);
let generatedImages = $state<string[]>([]);
let error = $state<string | null>(null);
let abortController = $state<AbortController | null>(null);
let showFullscreen = $state(false);
let fullscreenIndex = $state(0);
let showSettings = $state(false);
// SDAPI lora state
let availableLoras = $state<SdApiLora[]>([]);
let selectedLoras = $state<SdApiLoraRef[]>([]);
let isLoadingLoras = $state(false);
let lorasLoaded = $state(false);
let loraError = $state<string | null>(null);
let hasModels = $derived($models.some((m) => !m.unlisted));
let isSdapi = $derived($apiModeStore === "sdapi");
$effect(() => {
playgroundStores.imageGenerating.set(isGenerating);
});
async function loadLoras() {
if (!$selectedModelStore || isLoadingLoras) return;
isLoadingLoras = true;
loraError = null;
try {
const loras = await fetchSdLoras($selectedModelStore);
availableLoras = loras;
lorasLoaded = true;
} catch (err) {
availableLoras = [];
loraError = err instanceof Error ? err.message : "Failed to load LoRAs";
lorasLoaded = false;
} finally {
isLoadingLoras = false;
}
}
function addLora(event: Event) {
const select = event.target as HTMLSelectElement;
const path = select.value;
if (!path) return;
const lora = availableLoras.find((l) => l.path === path);
if (lora && !selectedLoras.some((l) => l.path === path)) {
selectedLoras = [...selectedLoras, { path: lora.path, multiplier: 1.0 }];
}
select.value = "";
}
function removeLora(path: string) {
selectedLoras = selectedLoras.filter((l) => l.path !== path);
}
function updateLoraMultiplier(path: string, multiplier: number) {
selectedLoras = selectedLoras.map((l) =>
l.path === path ? { ...l, multiplier } : l
);
}
function getLoraName(path: string): string {
return availableLoras.find((l) => l.path === path)?.name ?? path;
}
async function generate() {
const trimmedPrompt = prompt.trim();
if (!trimmedPrompt || !$selectedModelStore || isGenerating) return;
isGenerating = true;
error = null;
abortController = new AbortController();
try {
if (isSdapi) {
const [w, h] = $selectedSizeStore.split("x").map(Number);
const request = {
model: $selectedModelStore,
prompt: trimmedPrompt,
negative_prompt: $sdNegativePromptStore || undefined,
width: w,
height: h,
steps: $sdStepsStore,
cfg_scale: $sdCfgScaleStore,
seed: $sdSeedStore,
batch_size: $sdBatchSizeStore,
sampler_name: $sdSamplerStore || undefined,
scheduler: $sdSchedulerStore || undefined,
lora: selectedLoras.length > 0 ? selectedLoras : undefined,
};
const response = await generateSdImage(request, abortController.signal);
if (response.images && response.images.length > 0) {
generatedImages = response.images.map(
(img) => `data:image/png;base64,${img}`
);
}
} else {
const response = await generateImage(
$selectedModelStore,
trimmedPrompt,
$selectedSizeStore,
abortController.signal
);
if (response.data && response.data.length > 0) {
const imageData = response.data[0];
if (imageData.b64_json) {
generatedImages = [`data:image/png;base64,${imageData.b64_json}`];
} else if (imageData.url) {
generatedImages = [imageData.url];
}
}
}
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
// User cancelled
} else {
error = err instanceof Error ? err.message : "An error occurred";
}
} finally {
isGenerating = false;
abortController = null;
}
}
function cancelGeneration() {
abortController?.abort();
}
function clearImage() {
generatedImages = [];
error = null;
prompt = "";
}
function downloadImage(index: number = 0) {
const img = generatedImages[index];
if (!img) return;
const link = document.createElement("a");
link.href = img;
link.download = `generated-image-${Date.now()}-${index}.png`;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
}
function openFullscreen(index: number = 0) {
fullscreenIndex = index;
showFullscreen = true;
}
function closeFullscreen(event?: MouseEvent) {
if (event && event.target !== event.currentTarget) {
return;
}
showFullscreen = false;
}
function handleKeyDown(event: KeyboardEvent) {
if (event.key === "Enter" && !event.shiftKey) {
event.preventDefault();
generate();
}
}
</script>
<div class="flex flex-col h-full">
<!-- Model selector and mode toggle -->
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} />
<select
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value={$apiModeStore}
disabled={isGenerating}
>
<option value="openai">OpenAI</option>
<option value="sdapi">SDAPI</option>
</select>
<select
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value={$selectedSizeStore}
disabled={isGenerating}
>
<optgroup label="Square">
<option value="512x512">512x512</option>
<option value="1024x1024">1024x1024</option>
</optgroup>
<optgroup label="Landscape">
<option value="1024x768">1024x768 (4:3)</option>
<option value="1280x720">1280x720 (16:9)</option>
<option value="1792x1024">1792x1024 (SDXL)</option>
</optgroup>
<optgroup label="Portrait">
<option value="768x1024">768x1024 (3:4)</option>
<option value="720x1280">720x1280 (9:16)</option>
<option value="1024x1792">1024x1792 (SDXL)</option>
</optgroup>
</select>
{#if isSdapi}
<button
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface hover:bg-secondary-hover transition-colors"
onclick={() => showSettings = !showSettings}
>
{showSettings ? "Hide Settings" : "Settings"}
</button>
{/if}
</div>
<!-- SDAPI Settings Panel -->
{#if isSdapi && showSettings}
<div class="shrink-0 mb-4 p-4 rounded border border-gray-200 dark:border-white/10 bg-surface">
<div class="grid grid-cols-2 md:grid-cols-4 gap-3 mb-3">
<label class="flex flex-col gap-1">
<span class="text-xs text-txtsecondary">Steps</span>
<input
type="number"
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value={$sdStepsStore}
min="1"
max="150"
/>
</label>
<label class="flex flex-col gap-1">
<span class="text-xs text-txtsecondary">CFG Scale</span>
<input
type="number"
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value={$sdCfgScaleStore}
min="1"
max="30"
step="0.5"
/>
</label>
<label class="flex flex-col gap-1">
<span class="text-xs text-txtsecondary">Seed (-1 = random)</span>
<input
type="number"
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value={$sdSeedStore}
min="-1"
/>
</label>
<label class="flex flex-col gap-1">
<span class="text-xs text-txtsecondary">Batch Size</span>
<input
type="number"
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value={$sdBatchSizeStore}
min="1"
max="8"
/>
</label>
<label class="flex flex-col gap-1">
<span class="text-xs text-txtsecondary">Sampler</span>
<select
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value={$sdSamplerStore}
>
<option value="">Default</option>
<option value="euler_a">euler_a</option>
<option value="euler">euler</option>
<option value="heun">heun</option>
<option value="dpm2">dpm2</option>
<option value="dpmpp2s_a">dpmpp2s_a</option>
<option value="dpmpp2m">dpmpp2m</option>
<option value="dpmpp2mv2">dpmpp2mv2</option>
<option value="ipndm">ipndm</option>
<option value="ipndm_v">ipndm_v</option>
<option value="lcm">lcm</option>
<option value="ddim_trailing">ddim_trailing</option>
<option value="tcd">tcd</option>
</select>
</label>
<label class="flex flex-col gap-1">
<span class="text-xs text-txtsecondary">Scheduler</span>
<select
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value={$sdSchedulerStore}
>
<option value="">Auto for model</option>
<option value="discrete">discrete</option>
<option value="karras">karras</option>
<option value="exponential">exponential</option>
<option value="ays">ays</option>
<option value="gits">gits</option>
</select>
</label>
</div>
<label class="flex flex-col gap-1 mb-3">
<span class="text-xs text-txtsecondary">Negative Prompt</span>
<textarea
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary resize-y text-sm"
bind:value={$sdNegativePromptStore}
rows="2"
placeholder="Elements to avoid..."
></textarea>
</label>
<!-- LoRA Selection -->
<div>
<span class="text-xs text-txtsecondary block mb-1">LoRAs</span>
<div class="flex items-center gap-2 mb-2">
<button
class="px-3 py-1.5 text-sm rounded border border-gray-200 dark:border-white/10 bg-surface hover:bg-secondary-hover transition-colors disabled:opacity-50"
onclick={loadLoras}
disabled={!$selectedModelStore || isLoadingLoras}
>
{isLoadingLoras ? "Loading..." : lorasLoaded ? "Reload LoRAs" : "Load LoRAs"}
</button>
{#if lorasLoaded && availableLoras.length > 0}
<select
class="flex-1 px-2 py-1.5 text-sm rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
onchange={addLora}
>
<option value="">Add a LoRA...</option>
{#each availableLoras.filter((l) => !selectedLoras.some((s) => s.path === l.path)) as lora}
<option value={lora.path}>{lora.name}</option>
{/each}
</select>
{/if}
</div>
{#if loraError}
<p class="text-xs text-red-500 mb-1">{loraError}</p>
{/if}
{#if lorasLoaded && availableLoras.length === 0}
<p class="text-xs text-txtsecondary">No LoRAs available</p>
{/if}
{#if selectedLoras.length > 0}
<div class="flex flex-col gap-1.5">
{#each selectedLoras as lora}
<div class="flex items-center gap-2 text-sm">
<span class="flex-1 truncate">{getLoraName(lora.path)}</span>
<input
type="number"
class="w-20 px-1.5 py-1 text-xs rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-1 focus:ring-primary"
value={lora.multiplier}
oninput={(e) => updateLoraMultiplier(lora.path, parseFloat((e.target as HTMLInputElement).value) || 1)}
min="0"
max="2"
step="0.1"
/>
<button
class="px-1.5 py-0.5 text-xs rounded border border-gray-200 dark:border-white/10 hover:bg-red-500 hover:text-white hover:border-red-500 transition-colors"
onclick={() => removeLora(lora.path)}
aria-label="Remove LoRA"
>
x
</button>
</div>
{/each}
</div>
{/if}
</div>
</div>
{/if}
<!-- Empty state for no models configured -->
{#if !hasModels}
<div class="flex-1 flex items-center justify-center text-txtsecondary">
<p>No models configured. Add models to your configuration to generate images.</p>
</div>
{:else}
<!-- Image display area -->
<div class="flex-1 overflow-auto mb-4 flex items-center justify-center bg-surface border border-gray-200 dark:border-white/10 rounded">
{#if isGenerating}
<div class="text-center text-txtsecondary">
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
<p>Generating image...</p>
</div>
{:else if error}
<div class="text-center text-red-500 p-4">
<p class="font-medium">Error</p>
<p class="text-sm mt-1">{error}</p>
</div>
{:else if generatedImages.length > 1}
<!-- Grid for multiple images (batch) -->
<div class="grid grid-cols-2 gap-2 p-2 w-full h-full overflow-auto">
{#each generatedImages as img, i}
<div class="relative flex items-center justify-center">
<button
class="p-0 border-0 bg-transparent cursor-pointer"
onclick={() => openFullscreen(i)}
aria-label="View fullscreen"
>
<img
src={img}
alt="AI generated content {i + 1}"
class="max-w-full max-h-full object-contain hover:opacity-90 transition-opacity"
/>
</button>
<button
class="absolute bottom-2 right-2 p-1.5 bg-black/60 hover:bg-black/80 text-white rounded-full transition-colors"
onclick={(e) => { e.stopPropagation(); downloadImage(i); }}
aria-label="Download image"
>
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
</svg>
</button>
</div>
{/each}
</div>
{:else if generatedImages.length === 1}
<div class="relative max-w-full max-h-full flex items-center justify-center">
<button
class="p-0 border-0 bg-transparent cursor-pointer"
onclick={() => openFullscreen(0)}
aria-label="View fullscreen"
>
<img
src={generatedImages[0]}
alt="AI generated content"
class="max-w-full max-h-full object-contain hover:opacity-90 transition-opacity"
/>
</button>
<button
class="absolute bottom-2 right-2 p-2 bg-black/60 hover:bg-black/80 text-white rounded-full transition-colors"
onclick={(e) => { e.stopPropagation(); downloadImage(0); }}
aria-label="Download image"
>
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
</svg>
</button>
</div>
{:else}
<div class="text-center text-txtsecondary">
<p>Enter a prompt below to generate an image</p>
</div>
{/if}
</div>
<!-- Prompt input area -->
<div class="shrink-0 flex flex-col md:flex-row gap-2">
<ExpandableTextarea
bind:value={prompt}
placeholder="Describe the image you want to generate..."
rows={3}
onkeydown={handleKeyDown}
disabled={isGenerating || !$selectedModelStore}
/>
<div class="flex flex-row md:flex-col gap-2">
{#if isGenerating}
<button class="btn bg-red-500 hover:bg-red-600 text-white flex-1 md:flex-none" onclick={cancelGeneration}>
Cancel
</button>
{:else}
<button
class="btn bg-primary text-btn-primary-text hover:opacity-90 flex-1 md:flex-none"
onclick={generate}
disabled={!prompt.trim() || !$selectedModelStore}
>
Generate
</button>
<button
class="btn flex-1 md:flex-none"
onclick={clearImage}
disabled={generatedImages.length === 0 && !error && !prompt.trim()}
>
Clear
</button>
{/if}
</div>
</div>
{/if}
</div>
<!-- Fullscreen dialog -->
{#if showFullscreen && generatedImages[fullscreenIndex]}
<div
class="fixed inset-0 bg-black/90 z-50 flex items-center justify-center p-4"
onclick={(e) => closeFullscreen(e)}
onkeydown={(e) => e.key === 'Escape' && closeFullscreen()}
role="dialog"
aria-modal="true"
tabindex="-1"
>
<button
class="absolute top-4 right-4 text-white hover:text-gray-300 text-2xl w-10 h-10 flex items-center justify-center rounded-full hover:bg-white/10 transition-colors"
onclick={() => closeFullscreen()}
aria-label="Close fullscreen"
>
×
</button>
<img
src={generatedImages[fullscreenIndex]}
alt="AI generated content"
class="max-w-full max-h-full object-contain pointer-events-none"
/>
</div>
{/if}
@@ -0,0 +1,44 @@
<script lang="ts">
import { models } from "../../stores/api";
import { groupModels } from "../../lib/modelUtils";
interface Props {
value: string;
placeholder?: string;
disabled?: boolean;
}
let { value = $bindable(), placeholder = "Select a model...", disabled = false }: Props = $props();
let grouped = $derived(groupModels($models));
let hasModels = $derived(grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0);
</script>
{#if hasModels}
<select
class="min-w-0 flex-1 basis-48 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
bind:value
{disabled}
>
<option value="">{placeholder}</option>
{#if grouped.local.length > 0}
<optgroup label="Local">
{#each grouped.local as model (model.id)}
<option value={model.id}>{model.id}</option>
{#if model.aliases}
{#each model.aliases as alias (alias)}
<option value={alias}> {alias}</option>
{/each}
{/if}
{/each}
</optgroup>
{/if}
{#each Object.entries(grouped.peersByProvider).sort(([a], [b]) => a.localeCompare(b)) as [peerId, peerModels] (peerId)}
<optgroup label="Peer: {peerId}">
{#each peerModels as model (model.id)}
<option value={model.id}>{model.id}</option>
{/each}
</optgroup>
{/each}
</select>
{/if}
@@ -0,0 +1,14 @@
<script lang="ts">
interface Props {
featureName: string;
}
let { featureName }: Props = $props();
</script>
<div class="flex items-center justify-center h-full">
<div class="text-center text-txtsecondary">
<p class="text-lg">{featureName}</p>
<p class="text-sm mt-2">To be implemented</p>
</div>
</div>
@@ -0,0 +1,406 @@
<script lang="ts">
import { models } from "../../stores/api";
import { persistentStore } from "../../stores/persistent";
import { rerank } from "../../lib/rerankApi";
import { playgroundStores } from "../../stores/playgroundActivity";
import ModelSelector from "./ModelSelector.svelte";
type RerankRow = { doc: string; score: number | null };
type SortOrder = "none" | "asc" | "desc";
type EditorMode = "table" | "json";
const selectedModelStore = persistentStore<string>("playground-rerank-model", "");
const defaultQuery = "How do LLM's work?";
const defaultDocs = [
"Large language models (LLMs) use transformer architectures to predict the next token in a sequence based on massive amounts of text data.",
"LLMs are trained on diverse internet text, learning statistical patterns of language that allow them to generate coherent responses.",
"During training, LLMs minimize a loss function that measures the difference between predicted and actual tokens across billions of examples.",
"Attention mechanisms in transformers enable LLMs to weigh the importance of different words when generating output.",
"Fine\u2011tuning allows a pre\u2011trained LLM to adapt to a specific downstream task with a smaller dataset.",
"Neural networks consist of layers of interconnected neurons that adjust their weights during back\u2011propagation.",
"The history of the Roman Empire spanned over a thousand years.",
"Soccer is the most popular sport in many countries around the world.",
"Quantum computing uses qubits to perform calculations that are intractable for classical computers.",
];
let query = $state(defaultQuery);
let rows = $state<RerankRow[]>([
...defaultDocs.map((doc) => ({ doc, score: null })),
{ doc: "", score: null },
]);
let isLoading = $state(false);
let error = $state<string | null>(null);
let usage = $state<{ prompt_tokens: number; total_tokens: number } | null>(null);
let abortController: AbortController | null = null;
let sortOrder = $state<SortOrder>("desc");
let editorMode = $state<EditorMode>("table");
let jsonText = $state("");
let jsonError = $state<string | null>(null);
let hasModels = $derived($models.some((m) => !m.unlisted));
let canSubmit = $derived((() => {
if (!$selectedModelStore || isLoading) return false;
if (editorMode === "json") {
try {
const parsed = JSON.parse(jsonText) as Record<string, unknown>;
return (
typeof parsed.query === "string" &&
parsed.query.trim() !== "" &&
Array.isArray(parsed.documents) &&
(parsed.documents as unknown[]).some(
(d) => typeof d === "string" && (d as string).trim() !== ""
)
);
} catch {
return false;
}
}
return query.trim() !== "" && rows.some((r) => r.doc.trim() !== "");
})());
// Display rows with sort applied (display-only transform, rows[] is never mutated by sorting)
let displayRows = $derived((() => {
const indexed = rows.map((row, i) => ({ row, i }));
if (sortOrder === "none") return indexed;
return [...indexed].sort((a, b) => {
if (a.row.score === null && b.row.score === null) return 0;
if (a.row.score === null) return 1;
if (b.row.score === null) return -1;
return sortOrder === "desc"
? b.row.score - a.row.score
: a.row.score - b.row.score;
});
})());
// Auto-add a new empty row when the last row gets content (table mode only)
$effect(() => {
if (editorMode === "table" && rows[rows.length - 1]?.doc.trim() !== "") {
rows = [...rows, { doc: "", score: null }];
}
});
// Sync loading state to activity store
$effect(() => {
playgroundStores.rerankLoading.set(isLoading);
});
function switchToJson() {
if (editorMode === "json") return;
const docs = rows.filter((r) => r.doc.trim() !== "").map((r) => r.doc);
jsonText = JSON.stringify({ query, documents: docs }, null, 2);
jsonError = null;
editorMode = "json";
}
function switchToTable() {
if (editorMode === "table") return;
if (jsonText.trim() === "") {
query = "";
rows = [{ doc: "", score: null }];
jsonError = null;
editorMode = "table";
return;
}
try {
const parsed = JSON.parse(jsonText) as unknown;
if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
throw new Error("Expected a JSON object");
}
const obj = parsed as Record<string, unknown>;
if (typeof obj.query !== "string") throw new Error('"query" must be a string');
if (!Array.isArray(obj.documents)) throw new Error('"documents" must be an array');
query = obj.query;
const newRows: RerankRow[] = (obj.documents as unknown[]).map((d) => ({
doc: typeof d === "string" ? d : String(d),
score: null,
}));
if (newRows.length === 0 || newRows[newRows.length - 1].doc.trim() !== "") {
newRows.push({ doc: "", score: null });
}
rows = newRows;
jsonError = null;
editorMode = "table";
} catch (err) {
jsonError = err instanceof Error ? err.message : "Invalid JSON";
}
}
function cycleSortOrder() {
sortOrder = sortOrder === "none" ? "desc" : sortOrder === "desc" ? "asc" : "none";
}
function sortIndicator(): string {
if (sortOrder === "desc") return " ↓";
if (sortOrder === "asc") return " ↑";
return "";
}
async function submit() {
if (!canSubmit) return;
let submitQuery: string;
let nonEmptyEntries: { originalIndex: number; doc: string }[];
if (editorMode === "json") {
// Parse JSON, sync state to table, then submit
try {
const parsed = JSON.parse(jsonText) as Record<string, unknown>;
submitQuery = parsed.query as string;
const docs = (parsed.documents as string[]).filter((d) => d.trim() !== "");
const newRows: RerankRow[] = docs.map((d) => ({ doc: d, score: null }));
newRows.push({ doc: "", score: null });
rows = newRows;
query = submitQuery;
editorMode = "table";
} catch {
error = "Invalid JSON — fix before submitting";
return;
}
nonEmptyEntries = rows
.map((r, i) => ({ originalIndex: i, doc: r.doc }))
.filter((e) => e.doc.trim() !== "");
} else {
submitQuery = query;
nonEmptyEntries = rows
.map((r, i) => ({ originalIndex: i, doc: r.doc }))
.filter((e) => e.doc.trim() !== "");
}
isLoading = true;
error = null;
usage = null;
// Clear previous scores
rows = rows.map((r) => ({ ...r, score: null }));
abortController = new AbortController();
try {
const response = await rerank(
$selectedModelStore,
submitQuery,
nonEmptyEntries.map((e) => e.doc),
abortController.signal
);
usage = response.usage;
// Map result.index (position in submitted docs array) back to original rows[] index
const updated = rows.map((r) => ({ ...r }));
for (const result of response.results) {
const entry = nonEmptyEntries[result.index];
if (entry !== undefined) {
updated[entry.originalIndex].score = result.relevance_score;
}
}
rows = updated;
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
// User cancelled
} else {
error = err instanceof Error ? err.message : "An error occurred";
}
} finally {
isLoading = false;
abortController = null;
}
}
function cancel() {
abortController?.abort();
}
function clear() {
query = defaultQuery;
rows = [...defaultDocs.map((doc) => ({ doc, score: null })), { doc: "", score: null }];
error = null;
usage = null;
sortOrder = "desc";
jsonText = "";
jsonError = null;
}
function deleteRow(originalIndex: number) {
if (rows.length <= 1) return;
rows = rows.filter((_, i) => i !== originalIndex);
}
function updateDoc(originalIndex: number, value: string) {
const updated = rows.map((r) => ({ ...r }));
updated[originalIndex].doc = value;
rows = updated;
}
function scoreColor(score: number | null): string {
if (score === null) return "text-txtsecondary";
if (score > 0) return "text-green-600 dark:text-green-400";
return "text-red-500 dark:text-red-400";
}
function formatScore(score: number | null): string {
if (score === null) return "—";
return score.toFixed(3);
}
function handleKeyDown(e: KeyboardEvent) {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
submit();
}
}
let isCleared = $derived(
query === defaultQuery &&
rows.every((r, i) => r.score === null && r.doc === (defaultDocs[i] ?? "")) &&
rows.length === defaultDocs.length + 1 &&
!jsonText.trim() &&
!error &&
!usage
);
</script>
<div class="flex flex-col h-full">
<!-- Top bar: model selector + query input (table mode) + mode toggle -->
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a rerank model..." disabled={isLoading} />
{#if editorMode === "table"}
<input
type="text"
class="min-w-0 flex-1 basis-48 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
placeholder="Query..."
bind:value={query}
disabled={isLoading}
onkeydown={handleKeyDown}
/>
{/if}
<!-- Table / JSON toggle -->
<div class="flex rounded border border-gray-200 dark:border-white/10 overflow-hidden shrink-0">
<button
class="px-3 py-1.5 text-sm transition-colors {editorMode === 'table'
? 'bg-primary text-btn-primary-text'
: 'bg-surface hover:bg-secondary-hover'}"
onclick={switchToTable}
disabled={isLoading}
>
Table
</button>
<button
class="px-3 py-1.5 text-sm border-l border-gray-200 dark:border-white/10 transition-colors {editorMode === 'json'
? 'bg-primary text-btn-primary-text'
: 'bg-surface hover:bg-secondary-hover'}"
onclick={switchToJson}
disabled={isLoading}
>
JSON
</button>
</div>
</div>
{#if !hasModels}
<div class="flex-1 flex items-center justify-center text-txtsecondary">
<p>No models configured. Add models to your configuration to use reranking.</p>
</div>
{:else if editorMode === "json"}
<!-- JSON editor -->
<div class="flex-1 flex flex-col min-h-0 mb-4">
<textarea
class="flex-1 w-full font-mono text-sm px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary resize-none"
bind:value={jsonText}
disabled={isLoading}
placeholder={'{\n "query": "your search query",\n "documents": [\n "document one",\n "document two"\n ]\n}'}
spellcheck={false}
></textarea>
{#if jsonError}
<p class="mt-1 text-sm text-red-500">{jsonError}</p>
{/if}
</div>
{:else}
<!-- Document table -->
<div class="flex-1 overflow-y-auto mb-4 border border-gray-200 dark:border-white/10 rounded">
<table class="w-full border-collapse table-fixed">
<colgroup>
<col class="w-auto" />
<col style="width: 120px" />
<col style="width: 40px" />
</colgroup>
<thead class="sticky top-0 bg-surface border-b border-gray-200 dark:border-white/10">
<tr>
<th class="px-3 py-2 text-left text-sm font-medium text-txtsecondary">Document</th>
<th
class="px-3 py-2 text-right text-sm font-medium text-txtsecondary cursor-pointer select-none hover:text-txtprimary transition-colors"
onclick={cycleSortOrder}
>
Score{sortIndicator()}
</th>
<th class="px-2 py-2"></th>
</tr>
</thead>
<tbody>
{#each displayRows as { row, i } (i)}
<tr class="border-b border-gray-100 dark:border-white/5 last:border-0">
<td class="px-3 py-1.5">
<input
type="text"
class="w-full bg-transparent focus:outline-none focus:ring-1 focus:ring-primary rounded px-1 py-0.5"
placeholder={i === rows.length - 1 ? "Add document..." : "Document text..."}
value={row.doc}
oninput={(e) => updateDoc(i, (e.target as HTMLInputElement).value)}
disabled={isLoading}
onkeydown={handleKeyDown}
/>
</td>
<td class="px-3 py-1.5 text-right font-mono text-sm {scoreColor(row.score)}">
{#if isLoading && row.score === null && row.doc.trim() !== ""}
<span class="inline-block w-4 h-4 border-2 border-current border-t-transparent rounded-full animate-spin align-middle"></span>
{:else}
{formatScore(row.score)}
{/if}
</td>
<td class="px-2 py-1.5 text-center">
<button
class="w-7 h-7 flex items-center justify-center text-txtsecondary hover:text-red-500 transition-colors rounded disabled:opacity-30 disabled:cursor-not-allowed"
onclick={() => deleteRow(i)}
disabled={rows.length <= 1}
tabindex="-1"
aria-label="Remove row"
>
×
</button>
</td>
</tr>
{/each}
</tbody>
</table>
</div>
{/if}
<!-- Bottom toolbar -->
{#if hasModels}
<div class="shrink-0 flex flex-wrap items-center gap-2">
{#if isLoading}
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancel}>
Cancel
</button>
{:else}
<button
class="btn bg-primary text-btn-primary-text hover:opacity-90"
onclick={submit}
disabled={!canSubmit}
>
Rerank
</button>
<button class="btn" onclick={clear} disabled={isCleared}>
Clear
</button>
{/if}
{#if error}
<span class="text-sm text-red-500 ml-2">{error}</span>
{:else if usage}
<span class="text-sm text-txtsecondary ml-2">{usage.total_tokens} tokens</span>
{/if}
</div>
{/if}
</div>
@@ -0,0 +1,360 @@
<script lang="ts">
import { models } from "../../stores/api";
import { persistentStore } from "../../stores/persistent";
import { generateSpeech } from "../../lib/speechApi";
import { playgroundStores } from "../../stores/playgroundActivity";
import ModelSelector from "./ModelSelector.svelte";
import ExpandableTextarea from "./ExpandableTextarea.svelte";
const selectedModelStore = persistentStore<string>("playground-speech-model", "");
const selectedVoiceStore = persistentStore<string>("playground-speech-voice", "coral");
const autoPlayStore = persistentStore<boolean>("playground-speech-autoplay", false);
let inputText = $state("");
let isGenerating = $state(false);
let generatedAudioUrl = $state<string | null>(null);
let generatedVoice = $state<string | null>(null);
let generatedTimestamp = $state<Date | null>(null);
let error = $state<string | null>(null);
let abortController = $state<AbortController | null>(null);
let audioElement = $state<HTMLAudioElement | null>(null);
let availableVoices = $state<string[]>(["coral", "alloy", "echo", "fable", "onyx", "nova", "shimmer"]);
let isLoadingVoices = $state(false);
const defaultVoices = ["coral", "alloy", "echo", "fable", "onyx", "nova", "shimmer"];
const CACHE_KEY = "playground-speech-voices-cache";
function getVoicesCache(): Record<string, string[]> {
if (typeof window === "undefined") return {};
try {
const saved = localStorage.getItem(CACHE_KEY);
return saved ? JSON.parse(saved) : {};
} catch {
return {};
}
}
function saveVoicesCache(cache: Record<string, string[]>) {
if (typeof window === "undefined") return;
try {
localStorage.setItem(CACHE_KEY, JSON.stringify(cache));
} catch (e) {
console.error("Error saving voices cache", e);
}
}
let hasModels = $derived($models.some((m) => !m.unlisted));
let isInitialLoad = $state(true);
$effect(() => {
playgroundStores.speechGenerating.set(isGenerating);
});
// On page load, restore cached voices for the selected model if available
$effect(() => {
const model = $selectedModelStore;
if (isInitialLoad) {
isInitialLoad = false;
// If we have cached voices for this model, use them
const cache = getVoicesCache();
if (model && cache[model]) {
availableVoices = cache[model];
}
}
});
async function refreshVoices() {
const model = $selectedModelStore;
if (!model || isLoadingVoices) return;
isLoadingVoices = true;
try {
const response = await fetch(`/v1/audio/voices?model=${encodeURIComponent(model)}`);
if (!response.ok) {
// Fall back to default voices if API call fails
availableVoices = defaultVoices;
const cache = getVoicesCache();
cache[model] = defaultVoices;
saveVoicesCache(cache);
selectedVoiceStore.set(defaultVoices[0]);
return;
}
const data = await response.json();
// Expect response to be an array of voice strings or an object with a voices array
const voices = Array.isArray(data) ? data : (data.voices || defaultVoices);
const newVoices = voices.length > 0 ? voices : defaultVoices;
availableVoices = newVoices;
const cache = getVoicesCache();
cache[model] = newVoices;
saveVoicesCache(cache);
// Reset to first available voice
selectedVoiceStore.set(newVoices[0]);
} catch {
// Fall back to default voices on error
availableVoices = defaultVoices;
const cache = getVoicesCache();
cache[model] = defaultVoices;
saveVoicesCache(cache);
selectedVoiceStore.set(defaultVoices[0]);
} finally {
isLoadingVoices = false;
}
}
function handleVoiceChange(event: Event) {
const value = (event.target as HTMLSelectElement).value;
if (value === "(refresh)") {
refreshVoices();
} else {
selectedVoiceStore.set(value);
}
}
// Auto-play effect when new audio is generated
$effect(() => {
if (generatedAudioUrl && $autoPlayStore && audioElement) {
audioElement.load();
audioElement.play().catch(() => {
// Ignore auto-play errors (e.g., browser policy blocks)
});
}
});
async function generate() {
const trimmedText = inputText.trim();
if (!trimmedText || !$selectedModelStore || isGenerating) return;
isGenerating = true;
error = null;
abortController = new AbortController();
try {
const audioBlob = await generateSpeech(
$selectedModelStore,
trimmedText,
$selectedVoiceStore,
abortController.signal
);
// Revoke previous URL to prevent memory leaks
if (generatedAudioUrl) {
URL.revokeObjectURL(generatedAudioUrl);
}
// Create object URL for the audio blob and store metadata
generatedAudioUrl = URL.createObjectURL(audioBlob);
generatedVoice = $selectedVoiceStore;
generatedTimestamp = new Date();
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
// User cancelled
} else {
error = err instanceof Error ? err.message : "An error occurred";
}
} finally {
isGenerating = false;
abortController = null;
}
}
function cancelGeneration() {
abortController?.abort();
}
function clearInput() {
inputText = "";
}
function downloadAudio() {
if (!generatedAudioUrl) return;
const timestamp = (generatedTimestamp || new Date()).toISOString().replace(/[:.]/g, '-').slice(0, -5);
const voice = generatedVoice || 'speech';
const filename = `${voice}-${timestamp}.mp3`;
const a = document.createElement('a');
a.href = generatedAudioUrl;
a.download = filename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
}
function formatTimestamp(date: Date): string {
return date.toLocaleString(undefined, {
month: 'short',
day: 'numeric',
hour: 'numeric',
minute: '2-digit',
hour12: true
});
}
function handleKeyDown(event: KeyboardEvent) {
if (event.key === "Enter" && !event.shiftKey) {
event.preventDefault();
generate();
}
}
</script>
<div class="flex flex-col h-full">
<!-- Model and voice selectors -->
<div class="shrink-0 flex gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a speech model..." disabled={isGenerating} />
<div class="flex gap-2">
<select
class="shrink-0 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
value={$selectedVoiceStore}
onchange={handleVoiceChange}
disabled={isGenerating || isLoadingVoices || !$selectedModelStore}
>
{#each availableVoices as voice (voice)}
<option value={voice}>{voice}</option>
{/each}
<option value="(refresh)">(refresh)</option>
</select>
{#if $selectedModelStore && !getVoicesCache()[$selectedModelStore]}
<button
class="btn shrink-0"
onclick={refreshVoices}
disabled={isLoadingVoices}
title={isLoadingVoices ? "Loading voices..." : "Load voices for this model"}
>
{#if isLoadingVoices}
<svg class="w-5 h-5 animate-spin" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
{:else}
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"></path>
</svg>
{/if}
</button>
{/if}
</div>
</div>
<!-- Empty state for no models configured -->
{#if !hasModels}
<div class="flex-1 flex items-center justify-center text-txtsecondary">
<p>No models configured. Add models to your configuration to generate speech.</p>
</div>
{:else}
<!-- Audio display area -->
<div class="shrink-0 mb-4 bg-surface border border-gray-200 dark:border-white/10 rounded p-4 md:p-6">
{#if isGenerating}
<div class="flex items-center justify-center text-txtsecondary py-8">
<div class="text-center">
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
<p>Generating speech...</p>
</div>
</div>
{:else if error}
<div class="flex items-center justify-center py-8">
<div class="text-center text-red-500">
<p class="font-medium">Error</p>
<p class="text-sm mt-1">{error}</p>
</div>
</div>
{:else if generatedAudioUrl}
<div class="flex flex-col gap-4">
<!-- Header with metadata and download -->
<div class="flex items-center justify-between gap-4">
<div class="flex flex-wrap gap-3 text-sm text-txtsecondary">
{#if generatedVoice}
<span class="flex items-center gap-1">
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11a7 7 0 01-7 7m0 0a7 7 0 01-7-7m7 7v4m0 0H8m4 0h4m-4-8a3 3 0 01-3-3V5a3 3 0 116 0v6a3 3 0 01-3 3z"></path>
</svg>
{generatedVoice}
</span>
{/if}
{#if generatedTimestamp}
<span class="flex items-center gap-1">
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z"></path>
</svg>
{formatTimestamp(generatedTimestamp)}
</span>
{/if}
</div>
<button
class="btn shrink-0"
onclick={downloadAudio}
title="Download audio file"
>
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
</svg>
</button>
</div>
<!-- Audio player with larger controls -->
<div class="w-full">
<audio bind:this={audioElement} controls class="w-full h-12 md:h-16">
<source src={generatedAudioUrl} type="audio/mpeg" />
Your browser does not support the audio element.
</audio>
</div>
</div>
{:else}
<div class="flex items-center justify-center text-txtsecondary py-8">
<div class="text-center">
<svg class="w-12 h-12 md:w-16 md:h-16 mx-auto mb-2 opacity-40" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11a7 7 0 01-7 7m0 0a7 7 0 01-7-7m7 7v4m0 0H8m4 0h4m-4-8a3 3 0 01-3-3V5a3 3 0 116 0v6a3 3 0 01-3 3z"></path>
</svg>
<p>Enter text below to convert to speech</p>
</div>
</div>
{/if}
</div>
<!-- Text input area -->
<div class="flex-1 flex flex-col md:flex-row gap-2 min-h-0">
<ExpandableTextarea
bind:value={inputText}
placeholder="Enter text to convert to speech..."
rows={8}
onkeydown={handleKeyDown}
disabled={isGenerating || !$selectedModelStore}
/>
<div class="shrink-0 flex md:flex-col gap-2">
{#if isGenerating}
<button class="btn bg-red-500 hover:bg-red-600 text-white flex-1 md:flex-none" onclick={cancelGeneration}>
Cancel
</button>
{:else}
<button
class="btn bg-primary text-btn-primary-text hover:opacity-90 flex-1 md:flex-none"
onclick={generate}
disabled={!inputText.trim() || !$selectedModelStore}
>
Generate
</button>
<button
class="btn flex-1 md:flex-none"
onclick={clearInput}
disabled={!inputText.trim()}
>
Clear
</button>
<label class="flex items-center justify-center gap-2 text-sm cursor-pointer">
<input
type="checkbox"
bind:checked={$autoPlayStore}
class="cursor-pointer"
/>
Auto-play
</label>
{/if}
</div>
</div>
{/if}
</div>
@@ -1,4 +1,5 @@
@import "tailwindcss";
@import "katex/dist/katex.min.css";
@custom-variant dark (&:where([data-theme=dark], [data-theme=dark] *));
@theme {

Some files were not shown because too many files have changed in this diff Show More